Commit 8435f54d authored by Gregory Ashton
Browse files

Splits up the MCMC classes

This makes MCMCGlitchSearch a subclass of the more general
MCMCSearch class.
parent 1e111130
......@@ -346,15 +346,14 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat):
return twoFsegA + twoFsegB
class MCMCGlitchSearch(BaseSearchClass):
""" MCMC search using the SemiCoherentGlitchSearch """
class MCMCSearch(BaseSearchClass):
""" MCMC search using ComputeFstat"""
@initializer
def __init__(self, label, outdir, sftlabel, sftdir, theta_prior, tref,
tstart, tend, nsteps=[100, 100, 100], nwalkers=100, ntemps=1,
nglitch=0, theta_initial=None, minCoverFreq=None,
theta_initial=None, minCoverFreq=None,
maxCoverFreq=None, scatter_val=1e-4, betas=None,
detector=None, dtglitchmin=20*86400, earth_ephem=None,
sun_ephem=None):
detector=None, earth_ephem=None, sun_ephem=None):
"""
Parameters
label, outdir: str
......@@ -370,8 +369,6 @@ class MCMCGlitchSearch(BaseSearchClass):
Either a dictionary of distribution about which to distribute the
initial walkers about, an array (from which the walkers will be
scattered by scatter_val, or None in which case the prior is used.
nglitch: int
The number of glitches to allow
tref, tstart, tend: int
GPS seconds of the reference time, start time and end time
nsteps: list (m,)
......@@ -379,9 +376,6 @@ class MCMCGlitchSearch(BaseSearchClass):
give the nburn and nprod of the 'production' run, all entries
before are for iterative initialisation steps (usually just one)
e.g. [1000, 1000, 500].
dtglitchmin: int
The minimum duration (in seconds) of a segment between two glitches
or a glitch and the start/end of the data
nwalkers, ntemps: int
Number of walkers and temperatures
minCoverFreq, maxCoverFreq: float
......@@ -394,12 +388,14 @@ class MCMCGlitchSearch(BaseSearchClass):
"""
logging.info(('Set-up MCMC search with {} glitches for model {} on'
' data {}').format(self.nglitch, self.label,
self.sftlabel))
logging.info(
'Set-up MCMC search for model {} on data {}'.format(
self.label, self.sftlabel))
if os.path.isdir(outdir) is False:
os.mkdir(outdir)
self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
self.theta_prior['tstart'] = self.tstart
self.theta_prior['tend'] = self.tend
self.unpack_input_theta()
self.ndim = len(self.theta_keys)
self.sft_filepath = self.sftdir+'/*_'+self.sftlabel+"*sft"
......@@ -415,65 +411,43 @@ class MCMCGlitchSearch(BaseSearchClass):
def inititate_search_object(self):
logging.info('Setting up search object')
self.search = SemiCoherentGlitchSearch(
label=self.label, outdir=self.outdir, sftlabel=self.sftlabel,
sftdir=self.sftdir, tref=self.tref, tstart=self.tstart,
tend=self.tend, minCoverFreq=self.minCoverFreq,
self.search = ComputeFstat(
tref=self.tref, sftlabel=self.sftlabel,
sftdir=self.sftdir, minCoverFreq=self.minCoverFreq,
maxCoverFreq=self.maxCoverFreq, earth_ephem=self.earth_ephem,
sun_ephem=self.sun_ephem, detector=self.detector,
nglitch=self.nglitch)
sun_ephem=self.sun_ephem, detector=self.detector)
def logp(self, theta_vals, theta_prior, theta_keys, search):
if self.nglitch > 1:
ts = [self.tstart] + theta_vals[-self.nglitch:] + [self.tend]
if np.array_equal(ts, np.sort(ts)) is False:
return -np.inf
if any(np.diff(ts) < self.dtglitchmin):
return -np.inf
H = [self.Generic_lnprior(**theta_prior[key])(p) for p, key in
H = [self.generic_lnprior(**theta_prior[key])(p) for p, key in
zip(theta_vals, theta_keys)]
return np.sum(H)
def logl(self, theta, search):
for j, theta_i in enumerate(self.theta_idxs):
self.fixed_theta[theta_i] = theta[j]
FS = search.compute_nglitch_fstat(*self.fixed_theta)
FS = search.run_computefstatistic_single_point(*self.fixed_theta)
return FS
def unpack_input_theta(self):
glitch_keys = ['delta_F0', 'delta_F1', 'tglitch']
full_glitch_keys = list(np.array(
[[gk]*self.nglitch for gk in glitch_keys]).flatten())
full_theta_keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta']+full_glitch_keys
full_theta_keys = ['tstart', 'tend', 'F0', 'F1', 'F2', 'Alpha',
'Delta']
full_theta_keys_copy = copy.copy(full_theta_keys)
glitch_symbols = ['$\delta f$', '$\delta \dot{f}$', r'$t_{glitch}$']
full_glitch_symbols = list(np.array(
[[gs]*self.nglitch for gs in glitch_symbols]).flatten())
full_theta_symbols = (['$f$', '$\dot{f}$', '$\ddot{f}$', r'$\alpha$',
r'$\delta$'] + full_glitch_symbols)
full_theta_symbols = ['_', '_', '$f$', '$\dot{f}$', '$\ddot{f}$',
r'$\alpha$', r'$\delta$']
self.theta_keys = []
fixed_theta_dict = {}
for key, val in self.theta_prior.iteritems():
if type(val) is dict:
fixed_theta_dict[key] = 0
if key in glitch_keys:
for i in range(self.nglitch):
self.theta_keys.append(key)
else:
self.theta_keys.append(key)
self.theta_keys.append(key)
elif type(val) in [float, int, np.float64]:
fixed_theta_dict[key] = val
else:
raise ValueError(
'Type {} of {} in theta not recognised'.format(
type(val), key))
if key in glitch_keys:
for i in range(self.nglitch):
full_theta_keys_copy.pop(full_theta_keys_copy.index(key))
else:
full_theta_keys_copy.pop(full_theta_keys_copy.index(key))
full_theta_keys_copy.pop(full_theta_keys_copy.index(key))
if len(full_theta_keys_copy) > 0:
raise ValueError(('Input dictionary `theta` is missing the'
......@@ -489,13 +463,6 @@ class MCMCGlitchSearch(BaseSearchClass):
self.theta_symbols = [self.theta_symbols[i] for i in idxs]
self.theta_keys = [self.theta_keys[i] for i in idxs]
# Correct for number of glitches in the idxs
self.theta_idxs = np.array(self.theta_idxs)
while np.sum(self.theta_idxs[:-1] == self.theta_idxs[1:]) > 0:
for i, idx in enumerate(self.theta_idxs):
if idx in self.theta_idxs[:i]:
self.theta_idxs[i] += 1
def check_initial_points(self, p0):
initial_priors = np.array([
self.logp(p, self.theta_prior, self.theta_keys, self.search)
......@@ -525,7 +492,8 @@ class MCMCGlitchSearch(BaseSearchClass):
logpargs=(self.theta_prior, self.theta_keys, self.search),
loglargs=(self.search,), betas=self.betas)
p0 = self.GenerateInitial()
p0 = self.generate_initial_p0()
p0 = self.apply_corrections_to_p0(p0)
self.check_initial_points(p0)
ninit_steps = len(self.nsteps) - 2
......@@ -534,11 +502,12 @@ class MCMCGlitchSearch(BaseSearchClass):
j, ninit_steps, n))
sampler.run_mcmc(p0, n)
fig, axes = self.PlotWalkers(sampler, symbols=self.theta_symbols)
fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols)
fig.savefig('{}/{}_init_{}_walkers.png'.format(
self.outdir, self.label, j))
p0 = self.get_new_p0(sampler, scatter_val=self.scatter_val)
p0 = self.apply_corrections_to_p0(p0)
self.check_initial_points(p0)
sampler.reset()
......@@ -548,7 +517,7 @@ class MCMCGlitchSearch(BaseSearchClass):
nburn+nprod))
sampler.run_mcmc(p0, nburn+nprod)
fig, axes = self.PlotWalkers(sampler, symbols=self.theta_symbols)
fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols)
fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label))
samples = sampler.chain[0, :, nburn:, :].reshape((-1, self.ndim))
......@@ -622,14 +591,14 @@ class MCMCGlitchSearch(BaseSearchClass):
ax = axes[i][i]
xlim = ax.get_xlim()
s = samples[:, i]
prior = self.Generic_lnprior(**self.theta_prior[key])
prior = self.generic_lnprior(**self.theta_prior[key])
x = np.linspace(s.min(), s.max(), 100)
ax2 = ax.twinx()
ax2.get_yaxis().set_visible(False)
ax2.plot(x, [prior(xi) for xi in x], '-r')
ax.set_xlim(xlim)
def Generic_lnprior(self, **kwargs):
def generic_lnprior(self, **kwargs):
""" Return a lambda function of the pdf
Parameters
......@@ -679,7 +648,7 @@ class MCMCGlitchSearch(BaseSearchClass):
logging.info("kwargs:", kwargs)
raise ValueError("Print unrecognise distribution")
def GenerateRV(self, **kwargs):
def generate_rv(self, **kwargs):
dist_type = kwargs.pop('type')
if dist_type == "unif":
return np.random.uniform(low=kwargs['lower'], high=kwargs['upper'])
......@@ -694,8 +663,8 @@ class MCMCGlitchSearch(BaseSearchClass):
else:
raise ValueError("dist_type {} unknown".format(dist_type))
def PlotWalkers(self, sampler, symbols=None, alpha=0.4, color="k", temp=0,
start=None, stop=None, draw_vline=None):
def plot_walkers(self, sampler, symbols=None, alpha=0.4, color="k", temp=0,
start=None, stop=None, draw_vline=None):
""" Plot all the chains from a sampler """
shape = sampler.chain.shape
......@@ -725,38 +694,35 @@ class MCMCGlitchSearch(BaseSearchClass):
return fig, axes
def _generate_scattered_p0(self, p):
def apply_corrections_to_p0(self, p0):
    """ Apply any correction to the initial p0 values

    This implementation performs no correction: p0 is returned unchanged.

    Parameters
    p0:
        The initial walker positions

    Returns
    p0:
        The same object that was passed in
    """
    return p0
def generate_scattered_p0(self, p):
""" Generate a set of p0s scattered about p """
p0 = [[p + scatter_val * p * np.random.randn(self.ndim)
p0 = [[p + self.scatter_val * p * np.random.randn(self.ndim)
for i in xrange(self.nwalkers)]
for j in xrange(self.ntemps)]
return p0
def _sort_p0_times(self, p0):
p0 = np.array(p0)
p0[:, :, -self.nglitch:] = np.sort(p0[:, :, -self.nglitch:], axis=2)
return p0
def GenerateInitial(self):
def generate_initial_p0(self):
""" Generate a set of init vals for the walkers """
if type(self.theta_initial) == dict:
p0 = [[[self.GenerateRV(**self.theta_initial[key])
p0 = [[[self.generate_rv(**self.theta_initial[key])
for key in self.theta_keys]
for i in range(self.nwalkers)]
for j in range(self.ntemps)]
elif self.theta_initial is None:
p0 = [[[self.GenerateRV(**self.theta_prior[key])
p0 = [[[self.generate_rv(**self.theta_prior[key])
for key in self.theta_keys]
for i in range(self.nwalkers)]
for j in range(self.ntemps)]
elif len(self.theta_initial) == self.ndim:
p0 = self._generate_scattered_p0(self.theta_initial)
p0 = self.generate_scattered_p0(self.theta_initial)
else:
raise ValueError('theta_initial not understood')
if self.nglitch > 1:
p0 = self._sort_p0_times(p0)
return p0
def get_new_p0(self, sampler, scatter_val=1e-3):
......@@ -780,8 +746,6 @@ class MCMCGlitchSearch(BaseSearchClass):
p = pF[np.nanargmax(lnp)]
p0 = self._generate_scattered_p0(p)
if self.nglitch > 1:
p0 = self._sort_p0_times(p0)
return p0
def get_save_data_dictionary(self):
......@@ -923,6 +887,164 @@ class MCMCGlitchSearch(BaseSearchClass):
k, d[k], d[k+'_std']))
class MCMCGlitchSearch(MCMCSearch):
""" MCMC search using the SemiCoherentGlitchSearch """
@initializer
def __init__(self, label, outdir, sftlabel, sftdir, theta_prior, tref,
             tstart, tend, nsteps=[100, 100, 100], nwalkers=100, ntemps=1,
             nglitch=0, theta_initial=None, minCoverFreq=None,
             maxCoverFreq=None, scatter_val=1e-4, betas=None,
             detector=None, dtglitchmin=20*86400, earth_ephem=None,
             sun_ephem=None):
    """ Configure an MCMC search which allows for glitches

    NOTE(review): the body reads self.<argument> without assigning it
    first, so it presumably relies on @initializer having stored every
    argument as an attribute of the same name - confirm against the
    decorator.

    Parameters
    label, outdir: str
        A label and directory to read/write data from/to
    sftlabel, sftdir: str
        A label and directory in which to find the relevant sft file
    theta_prior: dict
        Dictionary of priors and fixed values for the search parameters.
        For each parameter (key of the dict), if it is to be held fixed
        the value should be the constant float, if it is to be searched,
        the value should be a dictionary of the prior.
    theta_initial: dict, array, (None)
        Either a dictionary of distributions about which to distribute
        the initial walkers, an array (from which the walkers will be
        scattered by scatter_val), or None in which case the prior is
        used.
    nglitch: int
        The number of glitches to allow
    tref, tstart, tend: int
        GPS seconds of the reference time, start time and end time
    nsteps: list (m,)
        List specifying the number of steps to take, the last two entries
        give the nburn and nprod of the 'production' run, all entries
        before are for iterative initialisation steps (usually just one)
        e.g. [1000, 1000, 500].
    dtglitchmin: int
        The minimum duration (in seconds) of a segment between two
        glitches or a glitch and the start/end of the data
    nwalkers, ntemps: int
        Number of walkers and temperatures
    minCoverFreq, maxCoverFreq: float
        Minimum and maximum instantaneous frequency which will be covered
        over the SFT time span as passed to CreateFstatInput
    earth_ephem, sun_ephem: str
        Paths of the two files containing positions of Earth and Sun,
        respectively at evenly spaced times, as passed to
        CreateFstatInput. If None, the defaults
        self.earth_ephem_default and self.sun_ephem_default are used.

    """

    message = ('Set-up MCMC glitch search with {} glitches for model {}'
               ' on data {}')
    logging.info(message.format(self.nglitch, self.label, self.sftlabel))

    # Make sure the output directory exists before anything is written
    if not os.path.isdir(outdir):
        os.mkdir(outdir)

    self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)

    # Split theta_prior into searched and fixed parameters; this sets
    # self.theta_keys, which defines the dimension of the search
    self.unpack_input_theta()
    self.ndim = len(self.theta_keys)

    # Glob pattern used to locate the SFT files
    self.sft_filepath = ''.join([self.sftdir, '/*_', self.sftlabel, "*sft"])

    # Fall back to the class-level default ephemeris files
    if earth_ephem is None:
        self.earth_ephem = self.earth_ephem_default
    if sun_ephem is None:
        self.sun_ephem = self.sun_ephem_default

    # With the clean flag set, move any previously saved data out of the
    # way instead of reusing it
    stale_pickle_exists = os.path.isfile(self.pickle_path)
    if args.clean and stale_pickle_exists:
        os.rename(self.pickle_path, self.pickle_path+".old")

    self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
def inititate_search_object(self):
    """ Create the SemiCoherentGlitchSearch and store it as self.search

    All constructor arguments are forwarded from the attributes of the
    same name on this instance.
    """
    logging.info('Setting up search object')
    search_kwargs = dict(
        label=self.label,
        outdir=self.outdir,
        sftlabel=self.sftlabel,
        sftdir=self.sftdir,
        tref=self.tref,
        tstart=self.tstart,
        tend=self.tend,
        minCoverFreq=self.minCoverFreq,
        maxCoverFreq=self.maxCoverFreq,
        earth_ephem=self.earth_ephem,
        sun_ephem=self.sun_ephem,
        detector=self.detector,
        nglitch=self.nglitch,
    )
    self.search = SemiCoherentGlitchSearch(**search_kwargs)
def logp(self, theta_vals, theta_prior, theta_keys, search):
    """ Log-prior of a point in parameter space

    When more than one glitch is allowed, the last nglitch entries of
    theta_vals are additionally required to be in increasing order
    between tstart and tend, with every gap (including the gaps to
    tstart and tend) at least dtglitchmin long.
    NOTE(review): this assumes the last nglitch searched parameters are
    the glitch times - confirm against unpack_input_theta.

    Parameters
    theta_vals: array or list
        Values of the searched parameters, ordered as theta_keys
    theta_prior: dict
        Prior specification for each key, passed to generic_lnprior
    theta_keys: list
        Key of each entry of theta_vals
    search:
        Unused here; kept so the signature matches the sampler call

    Returns
    logp: float
        Sum of the individual log-priors, or -np.inf if the ordering or
        minimum-separation constraint is violated
    """
    if self.nglitch > 1:
        # Convert to a list before joining: `list + ndarray` is
        # element-wise addition in numpy, not concatenation, which would
        # silently test the constraints on the wrong numbers.
        ts = ([self.tstart] + list(theta_vals[-self.nglitch:])
              + [self.tend])
        if np.array_equal(ts, np.sort(ts)) is False:
            return -np.inf
        if any(np.diff(ts) < self.dtglitchmin):
            return -np.inf

    H = [self.generic_lnprior(**theta_prior[key])(p) for p, key in
         zip(theta_vals, theta_keys)]
    return np.sum(H)
def logl(self, theta, search):
    """ Evaluate search.compute_nglitch_fstat at the point theta

    The searched values in theta are written into self.fixed_theta at
    the positions given by self.theta_idxs (self.fixed_theta is modified
    in place); the full parameter list is then unpacked as the positional
    arguments of search.compute_nglitch_fstat.

    Returns
    FS:
        Whatever search.compute_nglitch_fstat returns
    """
    params = self.fixed_theta
    for slot, target in enumerate(self.theta_idxs):
        params[target] = theta[slot]
    return search.compute_nglitch_fstat(*params)
def unpack_input_theta(self):
    """ Split self.theta_prior into searched and fixed parameters

    Entries of theta_prior whose value is a dict are searched over;
    entries whose value is a number are held fixed. The glitch
    parameters ('delta_F0', 'delta_F1', 'tglitch') are each repeated
    nglitch times in the full parameter list.

    Sets
    self.theta_keys: list
        Keys of the searched parameters, in full-parameter-list order
    self.theta_symbols: list
        Plot label of each searched parameter
    self.theta_idxs: numpy array
        Position of each searched parameter in self.fixed_theta
    self.fixed_theta: list
        The full parameter list, with 0 as a placeholder in every
        searched position

    Raises
    ValueError
        If a value is neither a dict nor a number, if a key is not a
        known parameter, or if a required key is missing
    """
    glitch_keys = ['delta_F0', 'delta_F1', 'tglitch']
    # Each glitch parameter appears nglitch times, grouped by parameter
    full_glitch_keys = [gk for gk in glitch_keys
                        for _ in range(self.nglitch)]
    full_theta_keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta']+full_glitch_keys
    full_theta_keys_copy = copy.copy(full_theta_keys)

    # Raw strings: '\d' is not a valid escape sequence in a normal string
    glitch_symbols = [r'$\delta f$', r'$\delta \dot{f}$', r'$t_{glitch}$']
    full_glitch_symbols = [gs for gs in glitch_symbols
                           for _ in range(self.nglitch)]
    full_theta_symbols = ([r'$f$', r'$\dot{f}$', r'$\ddot{f}$',
                           r'$\alpha$', r'$\delta$'] + full_glitch_symbols)
    self.theta_keys = []
    fixed_theta_dict = {}
    # .items() works under both Python 2 and Python 3
    for key, val in self.theta_prior.items():
        is_glitch_key = key in glitch_keys
        if not is_glitch_key and key not in full_theta_keys:
            raise ValueError(
                'Key {} in theta not recognised'.format(key))

        if isinstance(val, dict):
            # Searched parameter: placeholder value, filled in by logl
            fixed_theta_dict[key] = 0
            if is_glitch_key:
                self.theta_keys.extend([key]*self.nglitch)
            else:
                self.theta_keys.append(key)
        elif isinstance(val, (float, int, np.float64)):
            fixed_theta_dict[key] = val
        else:
            raise ValueError(
                'Type {} of {} in theta not recognised'.format(
                    type(val), key))

        # Tick the key off the list of required keys
        if is_glitch_key:
            for _ in range(self.nglitch):
                full_theta_keys_copy.remove(key)
        else:
            full_theta_keys_copy.remove(key)

    if len(full_theta_keys_copy) > 0:
        raise ValueError(('Input dictionary `theta` is missing the '
                          'following keys: {}').format(
                              full_theta_keys_copy))

    self.fixed_theta = [fixed_theta_dict[key] for key in full_theta_keys]
    self.theta_idxs = [full_theta_keys.index(k) for k in self.theta_keys]
    self.theta_symbols = [full_theta_symbols[i] for i in self.theta_idxs]

    # Put the searched parameters in full-parameter-list order
    idxs = np.argsort(self.theta_idxs)
    self.theta_idxs = [self.theta_idxs[i] for i in idxs]
    self.theta_symbols = [self.theta_symbols[i] for i in idxs]
    self.theta_keys = [self.theta_keys[i] for i in idxs]

    # Correct for number of glitches in the idxs: list.index returns the
    # first match, so repeated glitch keys all share one index; bump the
    # duplicates until each points at its own slot
    self.theta_idxs = np.array(self.theta_idxs)
    while np.sum(self.theta_idxs[:-1] == self.theta_idxs[1:]) > 0:
        for i, idx in enumerate(self.theta_idxs):
            if idx in self.theta_idxs[:i]:
                self.theta_idxs[i] += 1
def apply_corrections_to_p0(self, p0):
    """ Sort the last nglitch entries of every walker position

    p0 is always converted to a new numpy array. When more than one
    glitch is allowed, the trailing nglitch values along the last axis
    are sorted into increasing order.
    NOTE(review): presumably these trailing values are the glitch times,
    which logp requires to be ordered - confirm.

    Parameters
    p0: array-like, indexed with three axes
        The initial walker positions

    Returns
    p0: numpy array
        The corrected copy of the input
    """
    corrected = np.array(p0)
    if self.nglitch <= 1:
        # Zero or one glitch: there is nothing to order
        return corrected
    tail = slice(-self.nglitch, None)
    corrected[:, :, tail] = np.sort(corrected[:, :, tail], axis=2)
    return corrected
class GridGlitchSearch(BaseSearchClass):
""" Gridded search using the SemiCoherentGlitchSearch """
@initializer
......
......@@ -137,7 +137,7 @@ class TestSemiCoherentGlitchSearch(unittest.TestCase):
self.assertTrue(np.abs((FS - predicted_FS))/predicted_FS < 0.3)
class TestMCMCGlitchSearch(unittest.TestCase):
class TestMCMCSearch(unittest.TestCase):
label = "MCMCTest"
outdir = 'TestData'
......@@ -165,13 +165,12 @@ class TestMCMCGlitchSearch(unittest.TestCase):
Writer.make_data()
predicted_FS = Writer.predict_fstat()
theta = {'delta_F0': 0, 'delta_F1': 0, 'tglitch': tend,
'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-9*F0)},
theta = {'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-9*F0)},
'F1': {'type': 'norm', 'loc': F1, 'scale': np.abs(1e-9*F1)},
'F2': F2, 'Alpha': Alpha, 'Delta': Delta}
search = pyfstat.MCMCGlitchSearch(
label=self.label, outdir=self.outdir, theta=theta, tref=tref,
search = pyfstat.MCMCSearch(
label=self.label, outdir=self.outdir, theta_prior=theta, tref=tref,
sftlabel=self.label, sftdir=self.outdir,
tstart=tstart, tend=tend, nsteps=[100, 100], nwalkers=100,
ntemps=1)
......@@ -181,7 +180,8 @@ class TestMCMCGlitchSearch(unittest.TestCase):
print('Predicted twoF is {} while recovered is {}'.format(
predicted_FS, FS))
self.assertTrue(np.abs((FS - predicted_FS))/predicted_FS < 0.3)
self.assertTrue(
FS > predicted_FS or np.abs((FS-predicted_FS))/predicted_FS < 0.3)
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment