Skip to content
Snippets Groups Projects
Commit 0624fc7e authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Adds a semi-coherent search class

- Also changes the parameter vector in the normal fully coherent search
  to remove the tstart and tend. This is done by wrappin the
  run_computefstatistic_single_point in a compute_fullycoherent_det_stat
  function
- Removes theta0 from the saved data dict for the MCMCSearch
- Adds example usage of the semi-coherent search
parent c323cf14
No related branches found
No related tags found
No related merge requests found
import pyfstat
F0 = 30.0
F1 = -1e-10
F2 = 0
Alpha = 5e-3
Delta = 6e-2
tref = 362750407.0
tstart = 1000000000
duration = 100*86400
tend = tstart + duration
theta_prior = {'F0': {'type': 'unif', 'lower': F0*(1-1e-6), 'upper': F0*(1+1e-5)},
'F1': {'type': 'unif', 'lower': F1*(1+1e-2), 'upper': F1*(1-1e-2)},
'F2': F2,
'Alpha': Alpha,
'Delta': Delta
}
ntemps = 1
log10temperature_min = -1
nwalkers = 100
nsteps = [500, 500, 500]
mcmc = pyfstat.MCMCSemiCoherentSearch(
label='semi_coherent_search_using_MCMC', outdir='data', nsegs=20,
sftfilepath='data/*basic*sft', theta_prior=theta_prior, tref=tref,
minStartTime=tstart, maxStartTime=tend, nsteps=nsteps, nwalkers=nwalkers,
ntemps=ntemps, log10temperature_min=log10temperature_min)
mcmc.run()
mcmc.plot_corner(add_prior=True)
mcmc.print_summary()
...@@ -330,6 +330,15 @@ class ComputeFstat(object): ...@@ -330,6 +330,15 @@ class ComputeFstat(object):
self.windowRange.tauBand = 0 self.windowRange.tauBand = 0
self.windowRange.dtau = 1 self.windowRange.dtau = 1
def compute_fullycoherent_det_stat_single_point(
self, F0, F1, F2, Alpha, Delta, asini=None, period=None, ecc=None,
tp=None, argp=None):
""" Compute the fully-coherent det. statistic at a single point """
return self.run_computefstatistic_single_point(
self.minStartTime, self.maxStartTime, F0, F1, F2, Alpha, Delta,
asini, period, ecc, tp, argp)
def run_computefstatistic_single_point(self, tstart, tend, F0, F1, def run_computefstatistic_single_point(self, tstart, tend, F0, F1,
F2, Alpha, Delta, asini=None, F2, Alpha, Delta, asini=None,
period=None, ecc=None, tp=None, period=None, ecc=None, tp=None,
...@@ -431,6 +440,54 @@ class ComputeFstat(object): ...@@ -431,6 +440,54 @@ class ComputeFstat(object):
return ax return ax
class SemiCoherentSearch(BaseSearchClass, ComputeFstat):
""" A semi-coherent search """
@initializer
def __init__(self, label, outdir, tref, nsegs=None, sftfilepath=None,
binary=False, BSGL=False, minStartTime=None,
maxStartTime=None, minCoverFreq=None, maxCoverFreq=None,
detector=None, earth_ephem=None, sun_ephem=None):
"""
Parameters
----------
label, outdir: str
A label and directory to read/write data from/to.
tref, tstart, tend: int
GPS seconds of the reference time, and start and end of the data.
nsegs: int
The (fixed) number of segments
sftfilepath: str
File patern to match SFTs
For all other parameters, see pyfstat.ComputeFStat.
"""
self.fs_file_name = "{}/{}_FS.dat".format(self.outdir, self.label)
if self.earth_ephem is None:
self.earth_ephem = self.earth_ephem_default
if self.sun_ephem is None:
self.sun_ephem = self.sun_ephem_default
self.transient = True
self.init_computefstatistic_single_point()
self.init_semicoherent_parameters()
def init_semicoherent_parameters(self):
logging.info('Initialise semicoherent parameters')
self.tboundaries = np.linspace(self.minStartTime, self.maxStartTime,
self.nsegs+1)
def compute_nseg_fstat(self, F0, F1, F2, Alpha, Delta):
""" Returns the semi-coherent summed twoF """
twoFvals = [self.run_computefstatistic_single_point(
self.tboundaries[i], self.tboundaries[i+1], F0, F1, F2, Alpha,
Delta)
for i in range(self.nsegs)]
return np.sum(twoFvals)
class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat): class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat):
""" A semi-coherent glitch search """ A semi-coherent glitch search
...@@ -543,7 +600,7 @@ class MCMCSearch(BaseSearchClass): ...@@ -543,7 +600,7 @@ class MCMCSearch(BaseSearchClass):
log10temperature_min=-5, theta_initial=None, scatter_val=1e-10, log10temperature_min=-5, theta_initial=None, scatter_val=1e-10,
binary=False, BSGL=False, minCoverFreq=None, binary=False, BSGL=False, minCoverFreq=None,
maxCoverFreq=None, detector=None, earth_ephem=None, maxCoverFreq=None, detector=None, earth_ephem=None,
sun_ephem=None, theta0_idx=0): sun_ephem=None):
""" """
Parameters Parameters
label, outdir: str label, outdir: str
...@@ -597,8 +654,6 @@ class MCMCSearch(BaseSearchClass): ...@@ -597,8 +654,6 @@ class MCMCSearch(BaseSearchClass):
'Set-up MCMC search for model {} on data {}'.format( 'Set-up MCMC search for model {} on data {}'.format(
self.label, self.sftfilepath)) self.label, self.sftfilepath))
self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label) self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
self.theta_prior['tstart'] = self.tstart
self.theta_prior['tend'] = self.tend
self.unpack_input_theta() self.unpack_input_theta()
self.ndim = len(self.theta_keys) self.ndim = len(self.theta_keys)
if self.log10temperature_min: if self.log10temperature_min:
...@@ -644,19 +699,19 @@ class MCMCSearch(BaseSearchClass): ...@@ -644,19 +699,19 @@ class MCMCSearch(BaseSearchClass):
def logl(self, theta, search): def logl(self, theta, search):
for j, theta_i in enumerate(self.theta_idxs): for j, theta_i in enumerate(self.theta_idxs):
self.fixed_theta[theta_i] = theta[j] self.fixed_theta[theta_i] = theta[j]
FS = search.run_computefstatistic_single_point(*self.fixed_theta) FS = search.compute_fullycoherent_det_stat_single_point(
*self.fixed_theta)
return FS return FS
def unpack_input_theta(self): def unpack_input_theta(self):
full_theta_keys = ['tstart', 'tend', 'F0', 'F1', 'F2', 'Alpha', full_theta_keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta']
'Delta']
if self.binary: if self.binary:
full_theta_keys += [ full_theta_keys += [
'asini', 'period', 'ecc', 'tp', 'argp'] 'asini', 'period', 'ecc', 'tp', 'argp']
full_theta_keys_copy = copy.copy(full_theta_keys) full_theta_keys_copy = copy.copy(full_theta_keys)
full_theta_symbols = ['_', '_', '$f$', '$\dot{f}$', '$\ddot{f}$', full_theta_symbols = ['$f$', '$\dot{f}$', '$\ddot{f}$', r'$\alpha$',
r'$\alpha$', r'$\delta$'] r'$\delta$']
if self.binary: if self.binary:
full_theta_symbols += [ full_theta_symbols += [
'asini', 'period', 'period', 'ecc', 'tp', 'argp'] 'asini', 'period', 'period', 'ecc', 'tp', 'argp']
...@@ -1175,7 +1230,7 @@ class MCMCSearch(BaseSearchClass): ...@@ -1175,7 +1230,7 @@ class MCMCSearch(BaseSearchClass):
ntemps=self.ntemps, theta_keys=self.theta_keys, ntemps=self.ntemps, theta_keys=self.theta_keys,
theta_prior=self.theta_prior, scatter_val=self.scatter_val, theta_prior=self.theta_prior, scatter_val=self.scatter_val,
log10temperature_min=self.log10temperature_min, log10temperature_min=self.log10temperature_min,
theta0_idx=self.theta0_idx, BSGL=self.BSGL) BSGL=self.BSGL)
return d return d
def save_data(self, sampler, samples, lnprobs, lnlikes): def save_data(self, sampler, samples, lnprobs, lnlikes):
...@@ -1341,6 +1396,7 @@ class MCMCSearch(BaseSearchClass): ...@@ -1341,6 +1396,7 @@ class MCMCSearch(BaseSearchClass):
max_twoFd, max_twoF = self.get_max_twoF() max_twoFd, max_twoF = self.get_max_twoF()
median_std_d = self.get_median_stds() median_std_d = self.get_median_stds()
print('\nSummary:') print('\nSummary:')
if hasattr(self, 'theta0_idx'):
print('theta0 index: {}'.format(self.theta0_idx)) print('theta0 index: {}'.format(self.theta0_idx))
print('Max twoF: {} with parameters:'.format(max_twoF)) print('Max twoF: {} with parameters:'.format(max_twoF))
for k in np.sort(max_twoFd.keys()): for k in np.sort(max_twoFd.keys()):
...@@ -1636,6 +1692,14 @@ _ sftfilepath: str ...@@ -1636,6 +1692,14 @@ _ sftfilepath: str
if idx in self.theta_idxs[:i]: if idx in self.theta_idxs[:i]:
self.theta_idxs[i] += 1 self.theta_idxs[i] += 1
def get_save_data_dictionary(self):
d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers,
ntemps=self.ntemps, theta_keys=self.theta_keys,
theta_prior=self.theta_prior, scatter_val=self.scatter_val,
log10temperature_min=self.log10temperature_min,
theta0_idx=self.theta0_idx, BSGL=self.BSGL)
return d
def apply_corrections_to_p0(self, p0): def apply_corrections_to_p0(self, p0):
p0 = np.array(p0) p0 = np.array(p0)
if self.nglitch > 1: if self.nglitch > 1:
...@@ -1693,6 +1757,65 @@ _ sftfilepath: str ...@@ -1693,6 +1757,65 @@ _ sftfilepath: str
fig.savefig('{}/{}_twoFcumulative.png'.format(self.outdir, self.label)) fig.savefig('{}/{}_twoFcumulative.png'.format(self.outdir, self.label))
class MCMCSemiCoherentSearch(MCMCSearch):
""" MCMC search for a signal using the semi-coherent ComputeFstat """
@initializer
def __init__(self, label, outdir, sftfilepath, theta_prior, tref,
nsegs=None, nsteps=[100, 100, 100], nwalkers=100, binary=False,
ntemps=1, log10temperature_min=-5, theta_initial=None,
scatter_val=1e-10, detector=None, BSGL=False,
minStartTime=None, maxStartTime=None, minCoverFreq=None,
maxCoverFreq=None, earth_ephem=None, sun_ephem=None):
"""
"""
if os.path.isdir(outdir) is False:
os.mkdir(outdir)
self.add_log_file()
logging.info(('Set-up MCMC semi-coherent search for model {} on data'
'{}').format(
self.label, self.sftfilepath))
self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
self.unpack_input_theta()
self.ndim = len(self.theta_keys)
if self.log10temperature_min:
self.betas = np.logspace(0, self.log10temperature_min, self.ntemps)
else:
self.betas = None
if earth_ephem is None:
self.earth_ephem = self.earth_ephem_default
if sun_ephem is None:
self.sun_ephem = self.sun_ephem_default
if args.clean and os.path.isfile(self.pickle_path):
os.rename(self.pickle_path, self.pickle_path+".old")
self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
self.log_input()
def inititate_search_object(self):
logging.info('Setting up search object')
self.search = SemiCoherentSearch(
label=self.label, outdir=self.outdir, tref=self.tref,
nsegs=self.nsegs, sftfilepath=self.sftfilepath, binary=self.binary,
BSGL=self.BSGL, minStartTime=self.minStartTime,
maxStartTime=self.maxStartTime, minCoverFreq=self.minCoverFreq,
maxCoverFreq=self.maxCoverFreq, detector=self.detector,
earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem)
def logp(self, theta_vals, theta_prior, theta_keys, search):
H = [self.generic_lnprior(**theta_prior[key])(p) for p, key in
zip(theta_vals, theta_keys)]
return np.sum(H)
def logl(self, theta, search):
for j, theta_i in enumerate(self.theta_idxs):
self.fixed_theta[theta_i] = theta[j]
FS = search.compute_nseg_fstat(*self.fixed_theta)
return FS
class MCMCTransientSearch(MCMCSearch): class MCMCTransientSearch(MCMCSearch):
""" MCMC search for a transient signal using the ComputeFstat """ """ MCMC search for a transient signal using the ComputeFstat """
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment