Commit 83fa8892 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Add ability to specify a different initialisation to the prior

parent 889b0ab3
......@@ -350,22 +350,27 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat):
class MCMCGlitchSearch(BaseSearchClass):
""" MCMC search using the SemiCoherentGlitchSearch """
@initializer
def __init__(self, label, outdir, sftlabel, sftdir, theta, tref, tstart,
tend, nsteps=[100, 100, 100], nwalkers=100, ntemps=1,
nglitch=0, minCoverFreq=None, maxCoverFreq=None,
scatter_val=1e-4, betas=None, detector=None,
dtglitchmin=20*86400, earth_ephem=None, sun_ephem=None):
def __init__(self, label, outdir, sftlabel, sftdir, theta_prior, tref,
tstart, tend, nsteps=[100, 100, 100], nwalkers=100, ntemps=1,
nglitch=0, theta_initial=None, minCoverFreq=None,
maxCoverFreq=None, scatter_val=1e-4, betas=None,
detector=None, dtglitchmin=20*86400, earth_ephem=None,
sun_ephem=None):
"""
Parameters
label, outdir: str
A label and directory to read/write data from/to
sftlabel, sftdir: str
A label and directory in which to find the relevant sft file
theta: dict
theta_prior: dict
Dictionary of priors and fixed values for the search parameters.
For each parameters (key of the dict), if it is to be held fixed
the value should be the constant float, if it is be searched, the
value should be a dictionary of the prior.
theta_initial: dict, array, (None)
Either a dictionary of distribution about which to distribute the
initial walkers about, an array (from which the walkers will be
scattered by scatter_val, or None in which case the prior is used.
nglitch: int
The number of glitches to allow
tref, tstart, tend: int
......@@ -449,12 +454,10 @@ class MCMCGlitchSearch(BaseSearchClass):
[[gs]*self.nglitch for gs in glitch_symbols]).flatten())
full_theta_symbols = (['$f$', '$\dot{f}$', '$\ddot{f}$', r'$\alpha$',
r'$\delta$'] + full_glitch_symbols)
self.theta_prior = {}
self.theta_keys = []
fixed_theta_dict = {}
for key, val in self.theta.iteritems():
for key, val in self.theta_prior.iteritems():
if type(val) is dict:
self.theta_prior[key] = val
fixed_theta_dict[key] = 0
if key in glitch_keys:
for i in range(self.nglitch):
......@@ -627,33 +630,6 @@ class MCMCGlitchSearch(BaseSearchClass):
ax2.plot(x, [prior(xi) for xi in x], '-r')
ax.set_xlim(xlim)
def get_new_p0(self, sampler, scatter_val=1e-3):
""" Returns new initial positions for walkers are burn0 stage
This returns new positions for all walkers by scattering points about
the maximum posterior with scale `scatter_val`.
"""
if sampler.chain[:, :, -1, :].shape[0] == 1:
ntemps_temp = 1
else:
ntemps_temp = self.ntemps
pF = sampler.chain[:, :, -1, :].reshape(
ntemps_temp, self.nwalkers, self.ndim)[0, :, :]
lnp = sampler.lnprobability[:, :, -1].reshape(
self.ntemps, self.nwalkers)[0, :]
if any(np.isnan(lnp)):
logging.warning("The sampler has produced nan's")
p = pF[np.nanargmax(lnp)]
p0 = [[p + scatter_val * p * np.random.randn(self.ndim)
for i in xrange(self.nwalkers)] for j in xrange(self.ntemps)]
if self.nglitch > 1:
p0 = np.array(p0)
p0[:, :, -self.nglitch:] = np.sort(p0[:, :, -self.nglitch:],
axis=2)
return p0
def Generic_lnprior(self, **kwargs):
""" Return a lambda function of the pdf
......@@ -750,18 +726,63 @@ class MCMCGlitchSearch(BaseSearchClass):
return fig, axes
def _generate_scattered_p0(self, p):
""" Generate a set of p0s scattered about p """
p0 = [[p + scatter_val * p * np.random.randn(self.ndim)
for i in xrange(self.nwalkers)]
for j in xrange(self.ntemps)]
return p0
def _sort_p0_times(self, p0):
p0 = np.array(p0)
p0[:, :, -self.nglitch:] = np.sort(p0[:, :, -self.nglitch:], axis=2)
return p0
def GenerateInitial(self):
""" Generate a set of init vals for the walkers based on the prior """
p0 = [[[self.GenerateRV(**self.theta_prior[key])
for key in self.theta_keys]
for i in range(self.nwalkers)]
for j in range(self.ntemps)]
""" Generate a set of init vals for the walkers """
if type(self.theta_initial) == dict:
p0 = [[[self.GenerateRV(**self.theta_initial[key])
for key in self.theta_keys]
for i in range(self.nwalkers)]
for j in range(self.ntemps)]
elif self.theta_initial is None:
p0 = [[[self.GenerateRV(**self.theta_prior[key])
for key in self.theta_keys]
for i in range(self.nwalkers)]
for j in range(self.ntemps)]
elif len(self.theta_initial) == self.ndim:
p0 = self._generate_scattered_p0(self.theta_initial)
else:
raise ValueError('theta_initial not understood')
if self.nglitch > 1:
p0 = self._sort_p0_times(p0)
return p0
def get_new_p0(self, sampler, scatter_val=1e-3):
""" Returns new initial positions for walkers are burn0 stage
This returns new positions for all walkers by scattering points about
the maximum posterior with scale `scatter_val`.
"""
if sampler.chain[:, :, -1, :].shape[0] == 1:
ntemps_temp = 1
else:
ntemps_temp = self.ntemps
pF = sampler.chain[:, :, -1, :].reshape(
ntemps_temp, self.nwalkers, self.ndim)[0, :, :]
lnp = sampler.lnprobability[:, :, -1].reshape(
self.ntemps, self.nwalkers)[0, :]
if any(np.isnan(lnp)):
logging.warning("The sampler has produced nan's")
p = pF[np.nanargmax(lnp)]
p0 = self._generate_scattered_p0(p)
# Order the times to start the right way around
if self.nglitch > 1:
p0 = np.array(p0)
p0[:, :, -self.nglitch:] = np.sort(p0[:, :, -self.nglitch:],
axis=2)
p0 = self._sort_p0_times(p0)
return p0
def get_save_data_dictionary(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment