Skip to content
Snippets Groups Projects
Commit 8435f54d authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Splits up the MCMC classes

This makes the MCMCGlitchSearch a subclass of the more general
MCMCSearch
parent 1e111130
No related branches found
No related tags found
No related merge requests found
...@@ -346,15 +346,14 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat): ...@@ -346,15 +346,14 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat):
return twoFsegA + twoFsegB return twoFsegA + twoFsegB
class MCMCGlitchSearch(BaseSearchClass): class MCMCSearch(BaseSearchClass):
""" MCMC search using the SemiCoherentGlitchSearch """ """ MCMC search using ComputeFstat"""
@initializer @initializer
def __init__(self, label, outdir, sftlabel, sftdir, theta_prior, tref, def __init__(self, label, outdir, sftlabel, sftdir, theta_prior, tref,
tstart, tend, nsteps=[100, 100, 100], nwalkers=100, ntemps=1, tstart, tend, nsteps=[100, 100, 100], nwalkers=100, ntemps=1,
nglitch=0, theta_initial=None, minCoverFreq=None, theta_initial=None, minCoverFreq=None,
maxCoverFreq=None, scatter_val=1e-4, betas=None, maxCoverFreq=None, scatter_val=1e-4, betas=None,
detector=None, dtglitchmin=20*86400, earth_ephem=None, detector=None, earth_ephem=None, sun_ephem=None):
sun_ephem=None):
""" """
Parameters Parameters
label, outdir: str label, outdir: str
...@@ -370,8 +369,6 @@ class MCMCGlitchSearch(BaseSearchClass): ...@@ -370,8 +369,6 @@ class MCMCGlitchSearch(BaseSearchClass):
Either a dictionary of distribution about which to distribute the Either a dictionary of distribution about which to distribute the
initial walkers about, an array (from which the walkers will be initial walkers about, an array (from which the walkers will be
scattered by scatter_val, or None in which case the prior is used. scattered by scatter_val, or None in which case the prior is used.
nglitch: int
The number of glitches to allow
tref, tstart, tend: int tref, tstart, tend: int
GPS seconds of the reference time, start time and end time GPS seconds of the reference time, start time and end time
nsteps: list (m,) nsteps: list (m,)
...@@ -379,9 +376,6 @@ class MCMCGlitchSearch(BaseSearchClass): ...@@ -379,9 +376,6 @@ class MCMCGlitchSearch(BaseSearchClass):
give the nburn and nprod of the 'production' run, all entries give the nburn and nprod of the 'production' run, all entries
before are for iterative initialisation steps (usually just one) before are for iterative initialisation steps (usually just one)
e.g. [1000, 1000, 500]. e.g. [1000, 1000, 500].
dtglitchmin: int
The minimum duration (in seconds) of a segment between two glitches
or a glitch and the start/end of the data
nwalkers, ntemps: int nwalkers, ntemps: int
Number of walkers and temperatures Number of walkers and temperatures
minCoverFreq, maxCoverFreq: float minCoverFreq, maxCoverFreq: float
...@@ -394,12 +388,14 @@ class MCMCGlitchSearch(BaseSearchClass): ...@@ -394,12 +388,14 @@ class MCMCGlitchSearch(BaseSearchClass):
""" """
logging.info(('Set-up MCMC search with {} glitches for model {} on' logging.info(
' data {}').format(self.nglitch, self.label, 'Set-up MCMC search for model {} on data {}'.format(
self.sftlabel)) self.label, self.sftlabel))
if os.path.isdir(outdir) is False: if os.path.isdir(outdir) is False:
os.mkdir(outdir) os.mkdir(outdir)
self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label) self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
self.theta_prior['tstart'] = self.tstart
self.theta_prior['tend'] = self.tend
self.unpack_input_theta() self.unpack_input_theta()
self.ndim = len(self.theta_keys) self.ndim = len(self.theta_keys)
self.sft_filepath = self.sftdir+'/*_'+self.sftlabel+"*sft" self.sft_filepath = self.sftdir+'/*_'+self.sftlabel+"*sft"
...@@ -415,53 +411,35 @@ class MCMCGlitchSearch(BaseSearchClass): ...@@ -415,53 +411,35 @@ class MCMCGlitchSearch(BaseSearchClass):
def inititate_search_object(self): def inititate_search_object(self):
logging.info('Setting up search object') logging.info('Setting up search object')
self.search = SemiCoherentGlitchSearch( self.search = ComputeFstat(
label=self.label, outdir=self.outdir, sftlabel=self.sftlabel, tref=self.tref, sftlabel=self.sftlabel,
sftdir=self.sftdir, tref=self.tref, tstart=self.tstart, sftdir=self.sftdir, minCoverFreq=self.minCoverFreq,
tend=self.tend, minCoverFreq=self.minCoverFreq,
maxCoverFreq=self.maxCoverFreq, earth_ephem=self.earth_ephem, maxCoverFreq=self.maxCoverFreq, earth_ephem=self.earth_ephem,
sun_ephem=self.sun_ephem, detector=self.detector, sun_ephem=self.sun_ephem, detector=self.detector)
nglitch=self.nglitch)
def logp(self, theta_vals, theta_prior, theta_keys, search): def logp(self, theta_vals, theta_prior, theta_keys, search):
if self.nglitch > 1: H = [self.generic_lnprior(**theta_prior[key])(p) for p, key in
ts = [self.tstart] + theta_vals[-self.nglitch:] + [self.tend]
if np.array_equal(ts, np.sort(ts)) is False:
return -np.inf
if any(np.diff(ts) < self.dtglitchmin):
return -np.inf
H = [self.Generic_lnprior(**theta_prior[key])(p) for p, key in
zip(theta_vals, theta_keys)] zip(theta_vals, theta_keys)]
return np.sum(H) return np.sum(H)
def logl(self, theta, search): def logl(self, theta, search):
for j, theta_i in enumerate(self.theta_idxs): for j, theta_i in enumerate(self.theta_idxs):
self.fixed_theta[theta_i] = theta[j] self.fixed_theta[theta_i] = theta[j]
FS = search.compute_nglitch_fstat(*self.fixed_theta) FS = search.run_computefstatistic_single_point(*self.fixed_theta)
return FS return FS
def unpack_input_theta(self): def unpack_input_theta(self):
glitch_keys = ['delta_F0', 'delta_F1', 'tglitch'] full_theta_keys = ['tstart', 'tend', 'F0', 'F1', 'F2', 'Alpha',
full_glitch_keys = list(np.array( 'Delta']
[[gk]*self.nglitch for gk in glitch_keys]).flatten())
full_theta_keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta']+full_glitch_keys
full_theta_keys_copy = copy.copy(full_theta_keys) full_theta_keys_copy = copy.copy(full_theta_keys)
glitch_symbols = ['$\delta f$', '$\delta \dot{f}$', r'$t_{glitch}$'] full_theta_symbols = ['_', '_', '$f$', '$\dot{f}$', '$\ddot{f}$',
full_glitch_symbols = list(np.array( r'$\alpha$', r'$\delta$']
[[gs]*self.nglitch for gs in glitch_symbols]).flatten())
full_theta_symbols = (['$f$', '$\dot{f}$', '$\ddot{f}$', r'$\alpha$',
r'$\delta$'] + full_glitch_symbols)
self.theta_keys = [] self.theta_keys = []
fixed_theta_dict = {} fixed_theta_dict = {}
for key, val in self.theta_prior.iteritems(): for key, val in self.theta_prior.iteritems():
if type(val) is dict: if type(val) is dict:
fixed_theta_dict[key] = 0 fixed_theta_dict[key] = 0
if key in glitch_keys:
for i in range(self.nglitch):
self.theta_keys.append(key)
else:
self.theta_keys.append(key) self.theta_keys.append(key)
elif type(val) in [float, int, np.float64]: elif type(val) in [float, int, np.float64]:
fixed_theta_dict[key] = val fixed_theta_dict[key] = val
...@@ -469,10 +447,6 @@ class MCMCGlitchSearch(BaseSearchClass): ...@@ -469,10 +447,6 @@ class MCMCGlitchSearch(BaseSearchClass):
raise ValueError( raise ValueError(
'Type {} of {} in theta not recognised'.format( 'Type {} of {} in theta not recognised'.format(
type(val), key)) type(val), key))
if key in glitch_keys:
for i in range(self.nglitch):
full_theta_keys_copy.pop(full_theta_keys_copy.index(key))
else:
full_theta_keys_copy.pop(full_theta_keys_copy.index(key)) full_theta_keys_copy.pop(full_theta_keys_copy.index(key))
if len(full_theta_keys_copy) > 0: if len(full_theta_keys_copy) > 0:
...@@ -489,13 +463,6 @@ class MCMCGlitchSearch(BaseSearchClass): ...@@ -489,13 +463,6 @@ class MCMCGlitchSearch(BaseSearchClass):
self.theta_symbols = [self.theta_symbols[i] for i in idxs] self.theta_symbols = [self.theta_symbols[i] for i in idxs]
self.theta_keys = [self.theta_keys[i] for i in idxs] self.theta_keys = [self.theta_keys[i] for i in idxs]
# Correct for number of glitches in the idxs
self.theta_idxs = np.array(self.theta_idxs)
while np.sum(self.theta_idxs[:-1] == self.theta_idxs[1:]) > 0:
for i, idx in enumerate(self.theta_idxs):
if idx in self.theta_idxs[:i]:
self.theta_idxs[i] += 1
def check_initial_points(self, p0): def check_initial_points(self, p0):
initial_priors = np.array([ initial_priors = np.array([
self.logp(p, self.theta_prior, self.theta_keys, self.search) self.logp(p, self.theta_prior, self.theta_keys, self.search)
...@@ -525,7 +492,8 @@ class MCMCGlitchSearch(BaseSearchClass): ...@@ -525,7 +492,8 @@ class MCMCGlitchSearch(BaseSearchClass):
logpargs=(self.theta_prior, self.theta_keys, self.search), logpargs=(self.theta_prior, self.theta_keys, self.search),
loglargs=(self.search,), betas=self.betas) loglargs=(self.search,), betas=self.betas)
p0 = self.GenerateInitial() p0 = self.generate_initial_p0()
p0 = self.apply_corrections_to_p0(p0)
self.check_initial_points(p0) self.check_initial_points(p0)
ninit_steps = len(self.nsteps) - 2 ninit_steps = len(self.nsteps) - 2
...@@ -534,11 +502,12 @@ class MCMCGlitchSearch(BaseSearchClass): ...@@ -534,11 +502,12 @@ class MCMCGlitchSearch(BaseSearchClass):
j, ninit_steps, n)) j, ninit_steps, n))
sampler.run_mcmc(p0, n) sampler.run_mcmc(p0, n)
fig, axes = self.PlotWalkers(sampler, symbols=self.theta_symbols) fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols)
fig.savefig('{}/{}_init_{}_walkers.png'.format( fig.savefig('{}/{}_init_{}_walkers.png'.format(
self.outdir, self.label, j)) self.outdir, self.label, j))
p0 = self.get_new_p0(sampler, scatter_val=self.scatter_val) p0 = self.get_new_p0(sampler, scatter_val=self.scatter_val)
p0 = self.apply_corrections_to_p0(p0)
self.check_initial_points(p0) self.check_initial_points(p0)
sampler.reset() sampler.reset()
...@@ -548,7 +517,7 @@ class MCMCGlitchSearch(BaseSearchClass): ...@@ -548,7 +517,7 @@ class MCMCGlitchSearch(BaseSearchClass):
nburn+nprod)) nburn+nprod))
sampler.run_mcmc(p0, nburn+nprod) sampler.run_mcmc(p0, nburn+nprod)
fig, axes = self.PlotWalkers(sampler, symbols=self.theta_symbols) fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols)
fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label)) fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label))
samples = sampler.chain[0, :, nburn:, :].reshape((-1, self.ndim)) samples = sampler.chain[0, :, nburn:, :].reshape((-1, self.ndim))
...@@ -622,14 +591,14 @@ class MCMCGlitchSearch(BaseSearchClass): ...@@ -622,14 +591,14 @@ class MCMCGlitchSearch(BaseSearchClass):
ax = axes[i][i] ax = axes[i][i]
xlim = ax.get_xlim() xlim = ax.get_xlim()
s = samples[:, i] s = samples[:, i]
prior = self.Generic_lnprior(**self.theta_prior[key]) prior = self.generic_lnprior(**self.theta_prior[key])
x = np.linspace(s.min(), s.max(), 100) x = np.linspace(s.min(), s.max(), 100)
ax2 = ax.twinx() ax2 = ax.twinx()
ax2.get_yaxis().set_visible(False) ax2.get_yaxis().set_visible(False)
ax2.plot(x, [prior(xi) for xi in x], '-r') ax2.plot(x, [prior(xi) for xi in x], '-r')
ax.set_xlim(xlim) ax.set_xlim(xlim)
def Generic_lnprior(self, **kwargs): def generic_lnprior(self, **kwargs):
""" Return a lambda function of the pdf """ Return a lambda function of the pdf
Parameters Parameters
...@@ -679,7 +648,7 @@ class MCMCGlitchSearch(BaseSearchClass): ...@@ -679,7 +648,7 @@ class MCMCGlitchSearch(BaseSearchClass):
logging.info("kwargs:", kwargs) logging.info("kwargs:", kwargs)
raise ValueError("Print unrecognise distribution") raise ValueError("Print unrecognise distribution")
def GenerateRV(self, **kwargs): def generate_rv(self, **kwargs):
dist_type = kwargs.pop('type') dist_type = kwargs.pop('type')
if dist_type == "unif": if dist_type == "unif":
return np.random.uniform(low=kwargs['lower'], high=kwargs['upper']) return np.random.uniform(low=kwargs['lower'], high=kwargs['upper'])
...@@ -694,7 +663,7 @@ class MCMCGlitchSearch(BaseSearchClass): ...@@ -694,7 +663,7 @@ class MCMCGlitchSearch(BaseSearchClass):
else: else:
raise ValueError("dist_type {} unknown".format(dist_type)) raise ValueError("dist_type {} unknown".format(dist_type))
def PlotWalkers(self, sampler, symbols=None, alpha=0.4, color="k", temp=0, def plot_walkers(self, sampler, symbols=None, alpha=0.4, color="k", temp=0,
start=None, stop=None, draw_vline=None): start=None, stop=None, draw_vline=None):
""" Plot all the chains from a sampler """ """ Plot all the chains from a sampler """
...@@ -725,38 +694,35 @@ class MCMCGlitchSearch(BaseSearchClass): ...@@ -725,38 +694,35 @@ class MCMCGlitchSearch(BaseSearchClass):
return fig, axes return fig, axes
def _generate_scattered_p0(self, p): def apply_corrections_to_p0(self, p0):
""" Apply any correction to the initial p0 values """
return p0
def generate_scattered_p0(self, p):
""" Generate a set of p0s scattered about p """ """ Generate a set of p0s scattered about p """
p0 = [[p + scatter_val * p * np.random.randn(self.ndim) p0 = [[p + self.scatter_val * p * np.random.randn(self.ndim)
for i in xrange(self.nwalkers)] for i in xrange(self.nwalkers)]
for j in xrange(self.ntemps)] for j in xrange(self.ntemps)]
return p0 return p0
def _sort_p0_times(self, p0): def generate_initial_p0(self):
p0 = np.array(p0)
p0[:, :, -self.nglitch:] = np.sort(p0[:, :, -self.nglitch:], axis=2)
return p0
def GenerateInitial(self):
""" Generate a set of init vals for the walkers """ """ Generate a set of init vals for the walkers """
if type(self.theta_initial) == dict: if type(self.theta_initial) == dict:
p0 = [[[self.GenerateRV(**self.theta_initial[key]) p0 = [[[self.generate_rv(**self.theta_initial[key])
for key in self.theta_keys] for key in self.theta_keys]
for i in range(self.nwalkers)] for i in range(self.nwalkers)]
for j in range(self.ntemps)] for j in range(self.ntemps)]
elif self.theta_initial is None: elif self.theta_initial is None:
p0 = [[[self.GenerateRV(**self.theta_prior[key]) p0 = [[[self.generate_rv(**self.theta_prior[key])
for key in self.theta_keys] for key in self.theta_keys]
for i in range(self.nwalkers)] for i in range(self.nwalkers)]
for j in range(self.ntemps)] for j in range(self.ntemps)]
elif len(self.theta_initial) == self.ndim: elif len(self.theta_initial) == self.ndim:
p0 = self._generate_scattered_p0(self.theta_initial) p0 = self.generate_scattered_p0(self.theta_initial)
else: else:
raise ValueError('theta_initial not understood') raise ValueError('theta_initial not understood')
if self.nglitch > 1:
p0 = self._sort_p0_times(p0)
return p0 return p0
def get_new_p0(self, sampler, scatter_val=1e-3): def get_new_p0(self, sampler, scatter_val=1e-3):
...@@ -780,8 +746,6 @@ class MCMCGlitchSearch(BaseSearchClass): ...@@ -780,8 +746,6 @@ class MCMCGlitchSearch(BaseSearchClass):
p = pF[np.nanargmax(lnp)] p = pF[np.nanargmax(lnp)]
p0 = self._generate_scattered_p0(p) p0 = self._generate_scattered_p0(p)
if self.nglitch > 1:
p0 = self._sort_p0_times(p0)
return p0 return p0
def get_save_data_dictionary(self): def get_save_data_dictionary(self):
...@@ -923,6 +887,164 @@ class MCMCGlitchSearch(BaseSearchClass): ...@@ -923,6 +887,164 @@ class MCMCGlitchSearch(BaseSearchClass):
k, d[k], d[k+'_std'])) k, d[k], d[k+'_std']))
class MCMCGlitchSearch(MCMCSearch):
""" MCMC search using the SemiCoherentGlitchSearch """
@initializer
def __init__(self, label, outdir, sftlabel, sftdir, theta_prior, tref,
tstart, tend, nsteps=[100, 100, 100], nwalkers=100, ntemps=1,
nglitch=0, theta_initial=None, minCoverFreq=None,
maxCoverFreq=None, scatter_val=1e-4, betas=None,
detector=None, dtglitchmin=20*86400, earth_ephem=None,
sun_ephem=None):
"""
Parameters
label, outdir: str
A label and directory to read/write data from/to
sftlabel, sftdir: str
A label and directory in which to find the relevant sft file
theta_prior: dict
Dictionary of priors and fixed values for the search parameters.
For each parameters (key of the dict), if it is to be held fixed
the value should be the constant float, if it is be searched, the
value should be a dictionary of the prior.
theta_initial: dict, array, (None)
Either a dictionary of distribution about which to distribute the
initial walkers about, an array (from which the walkers will be
scattered by scatter_val, or None in which case the prior is used.
nglitch: int
The number of glitches to allow
tref, tstart, tend: int
GPS seconds of the reference time, start time and end time
nsteps: list (m,)
List specifying the number of steps to take, the last two entries
give the nburn and nprod of the 'production' run, all entries
before are for iterative initialisation steps (usually just one)
e.g. [1000, 1000, 500].
dtglitchmin: int
The minimum duration (in seconds) of a segment between two glitches
or a glitch and the start/end of the data
nwalkers, ntemps: int
Number of walkers and temperatures
minCoverFreq, maxCoverFreq: float
Minimum and maximum instantaneous frequency which will be covered
over the SFT time span as passed to CreateFstatInput
earth_ephem, sun_ephem: str
Paths of the two files containing positions of Earth and Sun,
respectively at evenly spaced times, as passed to CreateFstatInput
If None defaults defined in BaseSearchClass will be used
"""
logging.info(('Set-up MCMC glitch search with {} glitches for model {}'
' on data {}').format(self.nglitch, self.label,
self.sftlabel))
if os.path.isdir(outdir) is False:
os.mkdir(outdir)
self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
self.unpack_input_theta()
self.ndim = len(self.theta_keys)
self.sft_filepath = self.sftdir+'/*_'+self.sftlabel+"*sft"
if earth_ephem is None:
self.earth_ephem = self.earth_ephem_default
if sun_ephem is None:
self.sun_ephem = self.sun_ephem_default
if args.clean and os.path.isfile(self.pickle_path):
os.rename(self.pickle_path, self.pickle_path+".old")
self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
def inititate_search_object(self):
logging.info('Setting up search object')
self.search = SemiCoherentGlitchSearch(
label=self.label, outdir=self.outdir, sftlabel=self.sftlabel,
sftdir=self.sftdir, tref=self.tref, tstart=self.tstart,
tend=self.tend, minCoverFreq=self.minCoverFreq,
maxCoverFreq=self.maxCoverFreq, earth_ephem=self.earth_ephem,
sun_ephem=self.sun_ephem, detector=self.detector,
nglitch=self.nglitch)
def logp(self, theta_vals, theta_prior, theta_keys, search):
if self.nglitch > 1:
ts = [self.tstart] + theta_vals[-self.nglitch:] + [self.tend]
if np.array_equal(ts, np.sort(ts)) is False:
return -np.inf
if any(np.diff(ts) < self.dtglitchmin):
return -np.inf
H = [self.generic_lnprior(**theta_prior[key])(p) for p, key in
zip(theta_vals, theta_keys)]
return np.sum(H)
def logl(self, theta, search):
for j, theta_i in enumerate(self.theta_idxs):
self.fixed_theta[theta_i] = theta[j]
FS = search.compute_nglitch_fstat(*self.fixed_theta)
return FS
def unpack_input_theta(self):
glitch_keys = ['delta_F0', 'delta_F1', 'tglitch']
full_glitch_keys = list(np.array(
[[gk]*self.nglitch for gk in glitch_keys]).flatten())
full_theta_keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta']+full_glitch_keys
full_theta_keys_copy = copy.copy(full_theta_keys)
glitch_symbols = ['$\delta f$', '$\delta \dot{f}$', r'$t_{glitch}$']
full_glitch_symbols = list(np.array(
[[gs]*self.nglitch for gs in glitch_symbols]).flatten())
full_theta_symbols = (['$f$', '$\dot{f}$', '$\ddot{f}$', r'$\alpha$',
r'$\delta$'] + full_glitch_symbols)
self.theta_keys = []
fixed_theta_dict = {}
for key, val in self.theta_prior.iteritems():
if type(val) is dict:
fixed_theta_dict[key] = 0
if key in glitch_keys:
for i in range(self.nglitch):
self.theta_keys.append(key)
else:
self.theta_keys.append(key)
elif type(val) in [float, int, np.float64]:
fixed_theta_dict[key] = val
else:
raise ValueError(
'Type {} of {} in theta not recognised'.format(
type(val), key))
if key in glitch_keys:
for i in range(self.nglitch):
full_theta_keys_copy.pop(full_theta_keys_copy.index(key))
else:
full_theta_keys_copy.pop(full_theta_keys_copy.index(key))
if len(full_theta_keys_copy) > 0:
raise ValueError(('Input dictionary `theta` is missing the'
'following keys: {}').format(
full_theta_keys_copy))
self.fixed_theta = [fixed_theta_dict[key] for key in full_theta_keys]
self.theta_idxs = [full_theta_keys.index(k) for k in self.theta_keys]
self.theta_symbols = [full_theta_symbols[i] for i in self.theta_idxs]
idxs = np.argsort(self.theta_idxs)
self.theta_idxs = [self.theta_idxs[i] for i in idxs]
self.theta_symbols = [self.theta_symbols[i] for i in idxs]
self.theta_keys = [self.theta_keys[i] for i in idxs]
# Correct for number of glitches in the idxs
self.theta_idxs = np.array(self.theta_idxs)
while np.sum(self.theta_idxs[:-1] == self.theta_idxs[1:]) > 0:
for i, idx in enumerate(self.theta_idxs):
if idx in self.theta_idxs[:i]:
self.theta_idxs[i] += 1
def apply_corrections_to_p0(self, p0):
p0 = np.array(p0)
if self.nglitch > 1:
p0[:, :, -self.nglitch:] = np.sort(p0[:, :, -self.nglitch:],
axis=2)
return p0
class GridGlitchSearch(BaseSearchClass): class GridGlitchSearch(BaseSearchClass):
""" Gridded search using the SemiCoherentGlitchSearch """ """ Gridded search using the SemiCoherentGlitchSearch """
@initializer @initializer
......
...@@ -137,7 +137,7 @@ class TestSemiCoherentGlitchSearch(unittest.TestCase): ...@@ -137,7 +137,7 @@ class TestSemiCoherentGlitchSearch(unittest.TestCase):
self.assertTrue(np.abs((FS - predicted_FS))/predicted_FS < 0.3) self.assertTrue(np.abs((FS - predicted_FS))/predicted_FS < 0.3)
class TestMCMCGlitchSearch(unittest.TestCase): class TestMCMCSearch(unittest.TestCase):
label = "MCMCTest" label = "MCMCTest"
outdir = 'TestData' outdir = 'TestData'
...@@ -165,13 +165,12 @@ class TestMCMCGlitchSearch(unittest.TestCase): ...@@ -165,13 +165,12 @@ class TestMCMCGlitchSearch(unittest.TestCase):
Writer.make_data() Writer.make_data()
predicted_FS = Writer.predict_fstat() predicted_FS = Writer.predict_fstat()
theta = {'delta_F0': 0, 'delta_F1': 0, 'tglitch': tend, theta = {'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-9*F0)},
'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-9*F0)},
'F1': {'type': 'norm', 'loc': F1, 'scale': np.abs(1e-9*F1)}, 'F1': {'type': 'norm', 'loc': F1, 'scale': np.abs(1e-9*F1)},
'F2': F2, 'Alpha': Alpha, 'Delta': Delta} 'F2': F2, 'Alpha': Alpha, 'Delta': Delta}
search = pyfstat.MCMCGlitchSearch( search = pyfstat.MCMCSearch(
label=self.label, outdir=self.outdir, theta=theta, tref=tref, label=self.label, outdir=self.outdir, theta_prior=theta, tref=tref,
sftlabel=self.label, sftdir=self.outdir, sftlabel=self.label, sftdir=self.outdir,
tstart=tstart, tend=tend, nsteps=[100, 100], nwalkers=100, tstart=tstart, tend=tend, nsteps=[100, 100], nwalkers=100,
ntemps=1) ntemps=1)
...@@ -181,7 +180,8 @@ class TestMCMCGlitchSearch(unittest.TestCase): ...@@ -181,7 +180,8 @@ class TestMCMCGlitchSearch(unittest.TestCase):
print('Predicted twoF is {} while recovered is {}'.format( print('Predicted twoF is {} while recovered is {}'.format(
predicted_FS, FS)) predicted_FS, FS))
self.assertTrue(np.abs((FS - predicted_FS))/predicted_FS < 0.3) self.assertTrue(
FS > predicted_FS or np.abs((FS-predicted_FS))/predicted_FS < 0.3)
if __name__ == '__main__': if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment