diff --git a/prior_wrapper.py b/prior_wrapper.py index 06f40379d49e9a94419ad94ea1bfba3c835754d3..a2579823efcfa2c16f20e6d09e6e56fae5f00bad 100644 --- a/prior_wrapper.py +++ b/prior_wrapper.py @@ -1,12 +1,78 @@ +# prior_wrapper.py +"""Classes and functions that can wrap an Enterprise pta object for HBMs + +Simple usage instructions +------------------------- +Build your Enterprise PTA object as you normally would. Make sure you know what +the parameter names are for the signals that you want to place a hierarchical +prior on. Those parameter names will be matched with a regular expression. +For instance, for powerlaw red noise, often the parameters would be named: +'J0030+0451_red_noise_log10_A' +'J0030+0451_red_noise_gamma' +'J1713+0747_red_noise_log10_A', +... + +With the above naming convention, you could create an HBM like this: + +import prior_wrapper as pw + +wrapper = pw.EnterpriseWrapper( + pta=pta, + hyper_regexps = { + 'red_noise': { + 'log10_amp': '_red_noise_log10_A$', + 'gamma': '_red_noise_gamma$', + 'prior': pw.BoundedMvNormalPlHierarchicalPrior, + } + } +) + +x0 = wrapper.sample() +cov = np.diag(0.01 * np.ones_like(x0)) + +sampler = ptmcmc(len(x0), wrapper.log_likelihood, wrapper.log_prior, cov, outDir='./chains', resume=False) + +# parameter names are in: +# wrapper.param_names + +# Add the prior draws +for draw_function in wrapper.get_draw_from_prior_functions(): + sampler.addProposalToCycle(draw_function, 5) + + +# If you also want to model DMGP with a Hierarchical Prior, you could do: + +wrapper = pw.EnterpriseWrapper( + pta=pta, + hyper_regexps = { + 'red_noise': { + 'log10_amp': '_red_noise_log10_A$', + 'gamma': '_red_noise_gamma$', + 'prior': pw.BoundedMvNormalPlHierarchicalPrior, + }, + 'dmgp': { + 'log10_amp': '_dm_gp_log10_A$', + 'gamma': '_dm_gp_gamma$', + 'prior': pw.BoundedMvNormalPlHierarchicalPrior, + }, + } +) + +# This last one is largely untested. Since the prior ranges for the +# hyperparameters cannot be set dynamically at the moment, we may need to update +# the code a bit to make it work for more general priors + +""" + import numpy as np import scipy.stats as sstats import scipy.linalg as sl import re +from enterprise.signals.parameter import function class kumaraswamy_distribution(sstats.rv_continuous): """Kumaraswamy distribution like for scipy""" def _pdf(self, x, a, b): - # Adding a condition to ensure x is within the [0,1] interval return np.where((x >= 0) & (x <= 1), a * b * x**(a-1) * (1 - x**a)**(b-1), 0) def _cdf(self, x, a, b): @@ -107,31 +173,6 @@ class IntervalTransform(object): num = 1 + np.exp(p) return np.log(self._upper-self._lower) + np.log(1/num - 1/(num**2)) -def ptmcmc_to_inferencedata(chain, params): - """Inferencedata, used by Arviz, can be created from a chain like this""" - import arviz as az - import xarray as xr - samples = chain.reshape(1, -1, chain.shape[1]) - - datasets = {} - for i, name in enumerate(params): - data_array = xr.DataArray(samples[:, :, i], dims=("chain", "draw"), name=name) - datasets[name] = data_array - dataset = xr.merge(datasets.values()) - return az.convert_to_inference_data(dataset) - -def inferencedata_to_ptmcmc(idata): - """InferenceData to ptmcmc chain""" - - # Get parameter names - param_names = list(idata.posterior.data_vars) - - # Get MCMC chains - posterior = idata.posterior - chains = {var_name: posterior[var_name].values for var_name in param_names} - - return chains, param_names - def log_sum_exp(log_prior1, log_prior2): """Take the log of two exponential sums, stable numerically""" max_log_prior = np.maximum(log_prior1, log_prior2) @@ -166,6 +207,12 @@ def ptapar_mapping(pta): return ptapar_to_array, array_to_ptapar +@function +def powerlaw_flat_tail(f, log10_A=-16, gamma=5, log10_kappa=-7, components=2): + df = np.diff(np.concatenate((np.array([0]), f[::components]))) + return ( + (10**log10_A) ** 2 / 12.0 / np.pi**2 * const.fyr ** (gamma - 3) * f ** (-gamma) * np.repeat(df, components) + 10 ** (2*log10_kappa) + ) class BoundedMvNormalPlHierarchicalPrior(object): """Class to represent a Bounded MvNormal hierarchical prior for Enterprise Powerlaw Signals""" @@ -219,8 +266,8 @@ class BoundedMvNormalPlHierarchicalPrior(object): def set_hyperpriors(self): """Set the hyper parameter priors""" - self._mu_amp = sstats.uniform(loc=-20, scale=7) - #self._mu_gamma = sstats.uniform(loc=0, scale=7) + # TODO: allow for more flexible setting of hyperparameter ranges + self._mu_amp = sstats.uniform(loc=-20, scale=10) self._mu_gamma = sstats.uniform(loc=-4, scale=8) self._L_amp = sstats.uniform(loc=0.03, scale=3.47) self._L_gamma = sstats.uniform(loc=0.03, scale=3.47)