Commit cc9e3762 authored by Rutger van Haasteren

Added some documentation in the source file

parent 6279285f
# prior_wrapper.py
"""Classes and functions that can wrap an Enterprise pta object for HBMs
Simple usage instructions
-------------------------
Build your Enterprise PTA object as you normally would. Make sure you know the
parameter names of the signals you want to place a hierarchical prior on; those
parameter names are matched with a regular expression. For powerlaw red noise,
for instance, the parameters are often named:

    'J0030+0451_red_noise_log10_A',
    'J0030+0451_red_noise_gamma',
    'J1713+0747_red_noise_log10_A',
    ...

With the above naming convention, you could create an HBM like this:

    import numpy as np
    from PTMCMCSampler.PTMCMCSampler import PTSampler as ptmcmc

    import prior_wrapper as pw

    wrapper = pw.EnterpriseWrapper(
        pta=pta,
        hyper_regexps={
            'red_noise': {
                'log10_amp': '_red_noise_log10_A$',
                'gamma': '_red_noise_gamma$',
                'prior': pw.BoundedMvNormalPlHierarchicalPrior,
            },
        },
    )

    x0 = wrapper.sample()
    cov = np.diag(0.01 * np.ones_like(x0))
    sampler = ptmcmc(len(x0), wrapper.log_likelihood, wrapper.log_prior, cov,
                     outDir='./chains', resume=False)

    # The parameter names are in wrapper.param_names

    # Add the prior draws
    for draw_function in wrapper.get_draw_from_prior_functions():
        sampler.addProposalToCycle(draw_function, 5)
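    # Then run the sampler as usual (the number of iterations here is only
    # an illustrative value, not a recommendation):
    sampler.sample(x0, int(1e6))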
If you also want to model the DMGP with a hierarchical prior, you could do:

    wrapper = pw.EnterpriseWrapper(
        pta=pta,
        hyper_regexps={
            'red_noise': {
                'log10_amp': '_red_noise_log10_A$',
                'gamma': '_red_noise_gamma$',
                'prior': pw.BoundedMvNormalPlHierarchicalPrior,
            },
            'dmgp': {
                'log10_amp': '_dm_gp_log10_A$',
                'gamma': '_dm_gp_gamma$',
                'prior': pw.BoundedMvNormalPlHierarchicalPrior,
            },
        },
    )
This last configuration is largely untested. Since the prior ranges for the
hyperparameters cannot be set dynamically at the moment, the code may need some
updates to work with more general priors.
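One possible (untested) workaround is to subclass the prior and override
set_hyperpriors() with your own ranges; the class name and the specific ranges
below are purely illustrative:

    import scipy.stats as sstats

    class MyHierarchicalPrior(pw.BoundedMvNormalPlHierarchicalPrior):
        def set_hyperpriors(self):
            """Set custom hyperparameter prior ranges (illustrative values)"""
            self._mu_amp = sstats.uniform(loc=-20, scale=10)
            self._mu_gamma = sstats.uniform(loc=-4, scale=8)
            self._L_amp = sstats.uniform(loc=0.03, scale=3.47)
            self._L_gamma = sstats.uniform(loc=0.03, scale=3.47)

and then pass 'prior': MyHierarchicalPrior in the hyper_regexps dictionary.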
"""
import numpy as np
import scipy.stats as sstats
import scipy.linalg as sl
import re

from enterprise import constants as const  # needed for const.fyr in powerlaw_flat_tail
from enterprise.signals.parameter import function
class kumaraswamy_distribution(sstats.rv_continuous):
    """Kumaraswamy distribution, as a scipy.stats rv_continuous subclass"""
    def _pdf(self, x, a, b):
        # The support is [0, 1]: return zero density outside that interval
        return np.where((x >= 0) & (x <= 1), a * b * x**(a-1) * (1 - x**a)**(b-1), 0)
    def _cdf(self, x, a, b):
@@ -107,31 +173,6 @@ class IntervalTransform(object):
        num = 1 + np.exp(p)
        return np.log(self._upper - self._lower) + np.log(1/num - 1/(num**2))
def ptmcmc_to_inferencedata(chain, params):
    """Convert a PTMCMC chain (samples x parameters) into an Arviz InferenceData object"""
    import arviz as az
    import xarray as xr
    # Reshape to (chain, draw, parameter), with a single chain
    samples = chain.reshape(1, -1, chain.shape[1])
    datasets = {}
    for i, name in enumerate(params):
        data_array = xr.DataArray(samples[:, :, i], dims=("chain", "draw"), name=name)
        datasets[name] = data_array
    dataset = xr.merge(datasets.values())
    return az.convert_to_inference_data(dataset)
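# Example usage (sketch): PTMCMCSampler writes a few bookkeeping columns
# (log-posterior, log-likelihood, acceptance fractions) after the parameter
# columns in its chain file, so keep only the parameter columns:
#
#     chain = np.loadtxt('./chains/chain_1.txt')
#     idata = ptmcmc_to_inferencedata(chain[:, :len(params)], params)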
def inferencedata_to_ptmcmc(idata):
    """Convert an Arviz InferenceData object back to per-parameter MCMC chains"""
    # Get parameter names
    param_names = list(idata.posterior.data_vars)
    # Get the MCMC chains as a dict of arrays, keyed by parameter name
    posterior = idata.posterior
    chains = {var_name: posterior[var_name].values for var_name in param_names}
    return chains, param_names
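# Example usage (sketch), e.g. after burn-in removal or thinning with Arviz:
#
#     chains, param_names = inferencedata_to_ptmcmc(idata)
#     first_param_samples = chains[param_names[0]]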
def log_sum_exp(log_prior1, log_prior2):
    """Numerically stable log of a sum of exponentials: log(exp(log_prior1) + exp(log_prior2))"""
    max_log_prior = np.maximum(log_prior1, log_prior2)
@@ -166,6 +207,12 @@ def ptapar_mapping(pta):
    return ptapar_to_array, array_to_ptapar
@function
def powerlaw_flat_tail(f, log10_A=-16, gamma=5, log10_kappa=-7, components=2):
    """Powerlaw PSD with an additive flat tail at a level of 10**(2*log10_kappa)"""
    df = np.diff(np.concatenate((np.array([0]), f[::components])))
    return (
        (10**log10_A) ** 2 / 12.0 / np.pi**2 * const.fyr ** (gamma - 3)
        * f ** (-gamma) * np.repeat(df, components)
        + 10 ** (2 * log10_kappa)
    )
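# Example usage (sketch; the prior ranges below are only illustrative values):
# this PSD can be used like any Enterprise spectrum function, e.g.
#
#     from enterprise.signals import parameter, gp_signals
#
#     psd = powerlaw_flat_tail(
#         log10_A=parameter.Uniform(-20, -11),
#         gamma=parameter.Uniform(0, 7),
#         log10_kappa=parameter.Uniform(-9, -5),
#     )
#     rn = gp_signals.FourierBasisGP(spectrum=psd, components=30)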
class BoundedMvNormalPlHierarchicalPrior(object):
    """Bounded multivariate-normal hierarchical prior for Enterprise powerlaw signals"""
@@ -219,8 +266,8 @@
    def set_hyperpriors(self):
        """Set the hyperparameter priors"""
        # TODO: allow for more flexible setting of hyperparameter ranges
        self._mu_amp = sstats.uniform(loc=-20, scale=10)
        self._mu_gamma = sstats.uniform(loc=-4, scale=8)
        self._L_amp = sstats.uniform(loc=0.03, scale=3.47)
        self._L_gamma = sstats.uniform(loc=0.03, scale=3.47)
......