Commit 83fa8892 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Add ability to specify a different initialisation to the prior

parent 889b0ab3
...@@ -350,22 +350,27 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat): ...@@ -350,22 +350,27 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat):
class MCMCGlitchSearch(BaseSearchClass): class MCMCGlitchSearch(BaseSearchClass):
""" MCMC search using the SemiCoherentGlitchSearch """ """ MCMC search using the SemiCoherentGlitchSearch """
@initializer @initializer
def __init__(self, label, outdir, sftlabel, sftdir, theta, tref, tstart, def __init__(self, label, outdir, sftlabel, sftdir, theta_prior, tref,
tend, nsteps=[100, 100, 100], nwalkers=100, ntemps=1, tstart, tend, nsteps=[100, 100, 100], nwalkers=100, ntemps=1,
nglitch=0, minCoverFreq=None, maxCoverFreq=None, nglitch=0, theta_initial=None, minCoverFreq=None,
scatter_val=1e-4, betas=None, detector=None, maxCoverFreq=None, scatter_val=1e-4, betas=None,
dtglitchmin=20*86400, earth_ephem=None, sun_ephem=None): detector=None, dtglitchmin=20*86400, earth_ephem=None,
sun_ephem=None):
""" """
Parameters Parameters
label, outdir: str label, outdir: str
A label and directory to read/write data from/to A label and directory to read/write data from/to
sftlabel, sftdir: str sftlabel, sftdir: str
A label and directory in which to find the relevant sft file A label and directory in which to find the relevant sft file
theta: dict theta_prior: dict
Dictionary of priors and fixed values for the search parameters. Dictionary of priors and fixed values for the search parameters.
For each parameters (key of the dict), if it is to be held fixed For each parameters (key of the dict), if it is to be held fixed
the value should be the constant float, if it is be searched, the the value should be the constant float, if it is be searched, the
value should be a dictionary of the prior. value should be a dictionary of the prior.
theta_initial: dict, array, (None)
Either a dictionary of distribution about which to distribute the
initial walkers about, an array (from which the walkers will be
scattered by scatter_val, or None in which case the prior is used.
nglitch: int nglitch: int
The number of glitches to allow The number of glitches to allow
tref, tstart, tend: int tref, tstart, tend: int
...@@ -449,12 +454,10 @@ class MCMCGlitchSearch(BaseSearchClass): ...@@ -449,12 +454,10 @@ class MCMCGlitchSearch(BaseSearchClass):
[[gs]*self.nglitch for gs in glitch_symbols]).flatten()) [[gs]*self.nglitch for gs in glitch_symbols]).flatten())
full_theta_symbols = (['$f$', '$\dot{f}$', '$\ddot{f}$', r'$\alpha$', full_theta_symbols = (['$f$', '$\dot{f}$', '$\ddot{f}$', r'$\alpha$',
r'$\delta$'] + full_glitch_symbols) r'$\delta$'] + full_glitch_symbols)
self.theta_prior = {}
self.theta_keys = [] self.theta_keys = []
fixed_theta_dict = {} fixed_theta_dict = {}
for key, val in self.theta.iteritems(): for key, val in self.theta_prior.iteritems():
if type(val) is dict: if type(val) is dict:
self.theta_prior[key] = val
fixed_theta_dict[key] = 0 fixed_theta_dict[key] = 0
if key in glitch_keys: if key in glitch_keys:
for i in range(self.nglitch): for i in range(self.nglitch):
...@@ -627,33 +630,6 @@ class MCMCGlitchSearch(BaseSearchClass): ...@@ -627,33 +630,6 @@ class MCMCGlitchSearch(BaseSearchClass):
ax2.plot(x, [prior(xi) for xi in x], '-r') ax2.plot(x, [prior(xi) for xi in x], '-r')
ax.set_xlim(xlim) ax.set_xlim(xlim)
def get_new_p0(self, sampler, scatter_val=1e-3):
""" Returns new initial positions for walkers are burn0 stage
This returns new positions for all walkers by scattering points about
the maximum posterior with scale `scatter_val`.
"""
if sampler.chain[:, :, -1, :].shape[0] == 1:
ntemps_temp = 1
else:
ntemps_temp = self.ntemps
pF = sampler.chain[:, :, -1, :].reshape(
ntemps_temp, self.nwalkers, self.ndim)[0, :, :]
lnp = sampler.lnprobability[:, :, -1].reshape(
self.ntemps, self.nwalkers)[0, :]
if any(np.isnan(lnp)):
logging.warning("The sampler has produced nan's")
p = pF[np.nanargmax(lnp)]
p0 = [[p + scatter_val * p * np.random.randn(self.ndim)
for i in xrange(self.nwalkers)] for j in xrange(self.ntemps)]
if self.nglitch > 1:
p0 = np.array(p0)
p0[:, :, -self.nglitch:] = np.sort(p0[:, :, -self.nglitch:],
axis=2)
return p0
def Generic_lnprior(self, **kwargs): def Generic_lnprior(self, **kwargs):
""" Return a lambda function of the pdf """ Return a lambda function of the pdf
...@@ -750,18 +726,63 @@ class MCMCGlitchSearch(BaseSearchClass): ...@@ -750,18 +726,63 @@ class MCMCGlitchSearch(BaseSearchClass):
return fig, axes return fig, axes
def _generate_scattered_p0(self, p):
""" Generate a set of p0s scattered about p """
p0 = [[p + scatter_val * p * np.random.randn(self.ndim)
for i in xrange(self.nwalkers)]
for j in xrange(self.ntemps)]
return p0
def _sort_p0_times(self, p0):
p0 = np.array(p0)
p0[:, :, -self.nglitch:] = np.sort(p0[:, :, -self.nglitch:], axis=2)
return p0
def GenerateInitial(self): def GenerateInitial(self):
""" Generate a set of init vals for the walkers based on the prior """ """ Generate a set of init vals for the walkers """
p0 = [[[self.GenerateRV(**self.theta_prior[key])
for key in self.theta_keys] if type(self.theta_initial) == dict:
for i in range(self.nwalkers)] p0 = [[[self.GenerateRV(**self.theta_initial[key])
for j in range(self.ntemps)] for key in self.theta_keys]
for i in range(self.nwalkers)]
for j in range(self.ntemps)]
elif self.theta_initial is None:
p0 = [[[self.GenerateRV(**self.theta_prior[key])
for key in self.theta_keys]
for i in range(self.nwalkers)]
for j in range(self.ntemps)]
elif len(self.theta_initial) == self.ndim:
p0 = self._generate_scattered_p0(self.theta_initial)
else:
raise ValueError('theta_initial not understood')
if self.nglitch > 1:
p0 = self._sort_p0_times(p0)
return p0
def get_new_p0(self, sampler, scatter_val=1e-3):
""" Returns new initial positions for walkers are burn0 stage
This returns new positions for all walkers by scattering points about
the maximum posterior with scale `scatter_val`.
"""
if sampler.chain[:, :, -1, :].shape[0] == 1:
ntemps_temp = 1
else:
ntemps_temp = self.ntemps
pF = sampler.chain[:, :, -1, :].reshape(
ntemps_temp, self.nwalkers, self.ndim)[0, :, :]
lnp = sampler.lnprobability[:, :, -1].reshape(
self.ntemps, self.nwalkers)[0, :]
if any(np.isnan(lnp)):
logging.warning("The sampler has produced nan's")
p = pF[np.nanargmax(lnp)]
p0 = self._generate_scattered_p0(p)
# Order the times to start the right way around
if self.nglitch > 1: if self.nglitch > 1:
p0 = np.array(p0) p0 = self._sort_p0_times(p0)
p0[:, :, -self.nglitch:] = np.sort(p0[:, :, -self.nglitch:],
axis=2)
return p0 return p0
def get_save_data_dictionary(self): def get_save_data_dictionary(self):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment