From 83fa88920361ac305a0b726400485ea61211a167 Mon Sep 17 00:00:00 2001 From: Gregory Ashton <gregory.ashton@ligo.org> Date: Wed, 21 Sep 2016 15:30:33 +0200 Subject: [PATCH] Add ability to specify a different initialisation to the prior --- pyfstat.py | 111 +++++++++++++++++++++++++++++++---------------------- 1 file changed, 66 insertions(+), 45 deletions(-) diff --git a/pyfstat.py b/pyfstat.py index 639f0d1..f27fe96 100755 --- a/pyfstat.py +++ b/pyfstat.py @@ -350,22 +350,27 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat): class MCMCGlitchSearch(BaseSearchClass): """ MCMC search using the SemiCoherentGlitchSearch """ @initializer - def __init__(self, label, outdir, sftlabel, sftdir, theta, tref, tstart, - tend, nsteps=[100, 100, 100], nwalkers=100, ntemps=1, - nglitch=0, minCoverFreq=None, maxCoverFreq=None, - scatter_val=1e-4, betas=None, detector=None, - dtglitchmin=20*86400, earth_ephem=None, sun_ephem=None): + def __init__(self, label, outdir, sftlabel, sftdir, theta_prior, tref, + tstart, tend, nsteps=[100, 100, 100], nwalkers=100, ntemps=1, + nglitch=0, theta_initial=None, minCoverFreq=None, + maxCoverFreq=None, scatter_val=1e-4, betas=None, + detector=None, dtglitchmin=20*86400, earth_ephem=None, + sun_ephem=None): """ Parameters label, outdir: str A label and directory to read/write data from/to sftlabel, sftdir: str A label and directory in which to find the relevant sft file - theta: dict + theta_prior: dict Dictionary of priors and fixed values for the search parameters. For each parameters (key of the dict), if it is to be held fixed the value should be the constant float, if it is be searched, the value should be a dictionary of the prior. + theta_initial: dict, array, (None) + Either a dictionary of distribution about which to distribute the + initial walkers about, an array (from which the walkers will be + scattered by scatter_val, or None in which case the prior is used. nglitch: int The number of glitches to allow tref, tstart, tend: int @@ -449,12 +454,10 @@ class MCMCGlitchSearch(BaseSearchClass): [[gs]*self.nglitch for gs in glitch_symbols]).flatten()) full_theta_symbols = (['$f$', '$\dot{f}$', '$\ddot{f}$', r'$\alpha$', r'$\delta$'] + full_glitch_symbols) - self.theta_prior = {} self.theta_keys = [] fixed_theta_dict = {} - for key, val in self.theta.iteritems(): + for key, val in self.theta_prior.iteritems(): if type(val) is dict: - self.theta_prior[key] = val fixed_theta_dict[key] = 0 if key in glitch_keys: for i in range(self.nglitch): @@ -627,33 +630,6 @@ class MCMCGlitchSearch(BaseSearchClass): ax2.plot(x, [prior(xi) for xi in x], '-r') ax.set_xlim(xlim) - def get_new_p0(self, sampler, scatter_val=1e-3): - """ Returns new initial positions for walkers are burn0 stage - - This returns new positions for all walkers by scattering points about - the maximum posterior with scale `scatter_val`. - - """ - if sampler.chain[:, :, -1, :].shape[0] == 1: - ntemps_temp = 1 - else: - ntemps_temp = self.ntemps - pF = sampler.chain[:, :, -1, :].reshape( - ntemps_temp, self.nwalkers, self.ndim)[0, :, :] - lnp = sampler.lnprobability[:, :, -1].reshape( - self.ntemps, self.nwalkers)[0, :] - if any(np.isnan(lnp)): - logging.warning("The sampler has produced nan's") - - p = pF[np.nanargmax(lnp)] - p0 = [[p + scatter_val * p * np.random.randn(self.ndim) - for i in xrange(self.nwalkers)] for j in xrange(self.ntemps)] - if self.nglitch > 1: - p0 = np.array(p0) - p0[:, :, -self.nglitch:] = np.sort(p0[:, :, -self.nglitch:], - axis=2) - return p0 - def Generic_lnprior(self, **kwargs): """ Return a lambda function of the pdf @@ -750,18 +726,63 @@ class MCMCGlitchSearch(BaseSearchClass): return fig, axes + def _generate_scattered_p0(self, p): + """ Generate a set of p0s scattered about p """ + p0 = [[p + scatter_val * p * np.random.randn(self.ndim) + for i in xrange(self.nwalkers)] + for j in xrange(self.ntemps)] + return p0 + + def _sort_p0_times(self, p0): + p0 = np.array(p0) + p0[:, :, -self.nglitch:] = np.sort(p0[:, :, -self.nglitch:], axis=2) + return p0 + def GenerateInitial(self): - """ Generate a set of init vals for the walkers based on the prior """ - p0 = [[[self.GenerateRV(**self.theta_prior[key]) - for key in self.theta_keys] - for i in range(self.nwalkers)] - for j in range(self.ntemps)] + """ Generate a set of init vals for the walkers """ + + if type(self.theta_initial) == dict: + p0 = [[[self.GenerateRV(**self.theta_initial[key]) + for key in self.theta_keys] + for i in range(self.nwalkers)] + for j in range(self.ntemps)] + elif self.theta_initial is None: + p0 = [[[self.GenerateRV(**self.theta_prior[key]) + for key in self.theta_keys] + for i in range(self.nwalkers)] + for j in range(self.ntemps)] + elif len(self.theta_initial) == self.ndim: + p0 = self._generate_scattered_p0(self.theta_initial) + else: + raise ValueError('theta_initial not understood') + + if self.nglitch > 1: + p0 = self._sort_p0_times(p0) + return p0 + + def get_new_p0(self, sampler, scatter_val=1e-3): + """ Returns new initial positions for walkers are burn0 stage + + This returns new positions for all walkers by scattering points about + the maximum posterior with scale `scatter_val`. + + """ + if sampler.chain[:, :, -1, :].shape[0] == 1: + ntemps_temp = 1 + else: + ntemps_temp = self.ntemps + pF = sampler.chain[:, :, -1, :].reshape( + ntemps_temp, self.nwalkers, self.ndim)[0, :, :] + lnp = sampler.lnprobability[:, :, -1].reshape( + self.ntemps, self.nwalkers)[0, :] + if any(np.isnan(lnp)): + logging.warning("The sampler has produced nan's") + + p = pF[np.nanargmax(lnp)] + p0 = self._generate_scattered_p0(p) - # Order the times to start the right way around if self.nglitch > 1: - p0 = np.array(p0) - p0[:, :, -self.nglitch:] = np.sort(p0[:, :, -self.nglitch:], - axis=2) + p0 = self._sort_p0_times(p0) return p0 def get_save_data_dictionary(self): -- GitLab