Commit a42f0b0a authored by Gregory Ashton

Minor improvements to user interface

- Make minStartTime, maxStartTime and outdir optional arguments with defaults
- Add notes to the documentation on which arguments are optional
- Change the log-level on some command-line calls that don't matter
- Reorganise the tests performed after loading the data
- If minStartTime and maxStartTime are None, set them from the
  SFT_timestamps (see the usage sketch below)
- Remove default labels from plot_twoF_cumulative
- If add_pfs is passed, call generate_loudest automatically
- Save min/maxStartTime in the pickle and load them if required
parent 2adef425
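For orientation, the sketch below shows how a call to `MCMCSearch` reads once these defaults are in place: outdir falls back to 'data', and minStartTime/maxStartTime fall back to None and are later filled in from the SFT timestamps. It is a minimal, hypothetical example; the prior values, reference time and SFT file pattern are placeholders, not values from this commit.

import pyfstat

# Hypothetical prior: F0 is searched over a uniform prior, the remaining
# parameters are held fixed (all numbers here are placeholders).
theta_prior = {
    'F0': {'type': 'unif', 'lower': 29.9, 'upper': 30.1},
    'F1': 0.0,
    'F2': 0.0,
    'Alpha': 1.0,
    'Delta': 0.5,
}

mcmc = pyfstat.MCMCSearch(
    theta_prior=theta_prior,
    tref=1000000000,                       # placeholder GPS reference time
    label='example_search',                # outdir now defaults to 'data'
    sftfilepattern='data/*_example*.sft',  # placeholder SFT pattern
    # minStartTime/maxStartTime are omitted: they default to None and are
    # set from the SFT timestamps once the data are loaded.
)
mcmc.run()
mcmc.plot_corner()

Passing explicit minStartTime, maxStartTime or outdir still works as before and restricts the analysis in the usual way.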
@@ -423,12 +423,10 @@ class ComputeFstat(BaseSearchClass):
             constraints.minStartTime = lal.LIGOTimeGPS(self.minStartTime)
         if self.maxStartTime:
             constraints.maxStartTime = lal.LIGOTimeGPS(self.maxStartTime)
         logging.info('Loading data matching pattern {}'.format(
             self.sftfilepattern))
         SFTCatalog = lalpulsar.SFTdataFind(self.sftfilepattern, constraints)
-        detector_names = list(set([d.header.name for d in SFTCatalog.data]))
-        self.detector_names = detector_names
         SFT_timestamps = [d.header.epoch for d in SFTCatalog.data]
         self.SFT_timestamps = [float(s) for s in SFT_timestamps]
         if len(SFT_timestamps) == 0:
@@ -440,21 +438,33 @@ class ComputeFstat(BaseSearchClass):
             plot_hist(SFT_timestamps, height=5, bincount=50)
         except ImportError:
             pass
-        if len(detector_names) == 0:
-            raise ValueError('No data loaded.')
-        logging.info('Loaded {} data files from detectors {}'.format(
-            len(SFT_timestamps), detector_names))
         cl_tconv1 = 'lalapps_tconvert {}'.format(int(SFT_timestamps[0]))
-        output = helper_functions.run_commandline(cl_tconv1)
+        output = helper_functions.run_commandline(cl_tconv1,
+                                                  log_level=logging.DEBUG)
         tconvert1 = output.rstrip('\n')
         cl_tconv2 = 'lalapps_tconvert {}'.format(int(SFT_timestamps[-1]))
-        output = helper_functions.run_commandline(cl_tconv2)
+        output = helper_functions.run_commandline(cl_tconv2,
+                                                  log_level=logging.DEBUG)
         tconvert2 = output.rstrip('\n')
         logging.info('Data spans from {} ({}) to {} ({})'.format(
             int(SFT_timestamps[0]),
             tconvert1,
             int(SFT_timestamps[-1]),
             tconvert2))
+        if self.minStartTime is None:
+            self.minStartTime = int(SFT_timestamps[0])
+        if self.maxStartTime is None:
+            self.maxStartTime = int(SFT_timestamps[-1])
+        detector_names = list(set([d.header.name for d in SFTCatalog.data]))
+        self.detector_names = detector_names
+        if len(detector_names) == 0:
+            raise ValueError('No data loaded.')
+        logging.info('Loaded {} data files from detectors {}'.format(
+            len(SFT_timestamps), detector_names))
         return SFTCatalog

     def init_computefstatistic_single_point(self):
@@ -735,7 +745,7 @@ class ComputeFstat(BaseSearchClass):
     def plot_twoF_cumulative(self, label, outdir, add_pfs=False, N=15,
                              injectSources=None, ax=None, c='k', savefig=True,
-                             title=None, **kwargs):
+                             title=None, plt_label=None, **kwargs):
         """ Plot the twoF value cumulatively

         Parameters
@@ -753,8 +763,8 @@ class ComputeFstat(BaseSearchClass):
            Colour
        savefig : bool
            If true, save the figure in outdir
-       title: str
-           Figure title
+       title, plt_label: str
+           Figure title and label

        Returns
        -------
@@ -775,7 +785,7 @@ class ComputeFstat(BaseSearchClass):
            pfs_input = None

        taus, twoFs = self.calculate_twoF_cumulative(**kwargs)
-       ax.plot(taus/86400., twoFs, label='All detectors', color=c)
+       ax.plot(taus/86400., twoFs, label=plt_label, color=c)
        if len(self.detector_names) > 1:
            detector_names = self.detector_names
            detectors = self.detectors
@@ -819,7 +829,8 @@ class ComputeFstat(BaseSearchClass):
        else:
            ax.set_ylabel(r'$\widetilde{2\mathcal{F}}_{\rm cumulative}$')
        ax.set_xlim(0, taus[-1]/86400)
-       ax.legend(frameon=False, loc=2, fontsize=6)
+       if plt_label:
+           ax.legend(frameon=False, loc=2, fontsize=6)
        if title:
            ax.set_title(title)
        if savefig:
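A small aside on the plot_twoF_cumulative change above: the curve label is no longer hard-coded to 'All detectors', and a legend is only drawn when a label is supplied via the new plt_label argument. A hypothetical call, assuming a ComputeFstat instance `search` that has already been configured elsewhere, could look like:

# Hypothetical: 'search' is an already-configured ComputeFstat instance.
# With plt_label set, the curve gets a legend entry; if plt_label is left
# as None, no legend is drawn (the new behaviour in this commit).
search.plot_twoF_cumulative(label='example', outdir='data',
                            plt_label='All detectors',
                            title='Cumulative twoF')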
@@ -26,52 +26,55 @@ class MCMCSearch(core.BaseSearchClass):
     Parameters
     ----------
-    label, outdir: str
-        A label and directory to read/write data from/to
     theta_prior: dict
         Dictionary of priors and fixed values for the search parameters.
         For each parameters (key of the dict), if it is to be held fixed
         the value should be the constant float, if it is be searched, the
         value should be a dictionary of the prior.
     tref, minStartTime, maxStartTime: int
-        GPS seconds of the reference time, start time and end time
-    sftfilepattern: str
+        GPS seconds of the reference time, start time and end time. While tref
+        is requirede, minStartTime and maxStartTime default to None in which
+        case all available data is used.
+    label, outdir: str
+        A label and output directory (optional, defaults is `'data'`) to
+        name files
+    sftfilepattern: str, optional
         Pattern to match SFTs using wildcards (*?) and ranges [0-9];
         mutiple patterns can be given separated by colons.
-    detectors: str
+    detectors: str, optional
         Two character reference to the detectors to use, specify None for no
         contraint and comma separate for multiple references.
-    nsteps: list (2,)
+    nsteps: list (2,), optional
         Number of burn-in and production steps to take, [nburn, nprod]. See
         `pyfstat.MCMCSearch.setup_initialisation()` for details on adding
         initialisation steps.
-    nwalkers, ntemps: int,
+    nwalkers, ntemps: int, optional
         The number of walkers and temperates to use in the parallel
         tempered PTSampler.
-    log10beta_min float < 0
+    log10beta_min float < 0, optional
         The log_10(beta) value, if given the set of betas passed to PTSampler
         are generated from `np.logspace(0, log10beta_min, ntemps)` (given
         in descending order to emcee).
-    theta_initial: dict, array, (None)
+    theta_initial: dict, array, optional
         A dictionary of distribution about which to distribute the
         initial walkers about
-    rhohatmax: float,
+    rhohatmax: float, optional
         Upper bound for the SNR scale parameter (required to normalise the
         Bayes factor) - this needs to be carefully set when using the
         evidence.
-    binary: bool
+    binary: bool, optional
         If true, search over binary parameters
-    BSGL: bool
+    BSGL: bool, optional
         If true, use the BSGL statistic
-    SSBPrec: int
+    SSBPrec: int, optional
         SSBPrec (SSB precision) to use when calling ComputeFstat
-    minCoverFreq, maxCoverFreq: float
+    minCoverFreq, maxCoverFreq: float, optional
         Minimum and maximum instantaneous frequency which will be covered
         over the SFT time span as passed to CreateFstatInput
-    injectSources: dict
+    injectSources: dict, optional
         If given, inject these properties into the SFT files before running
         the search
-    assumeSqrtSX: float
+    assumeSqrtSX: float, optional
         Don't estimate noise-floors, but assume (stationary) per-IFO sqrt{SX}

     Attributes
@@ -99,9 +102,9 @@ class MCMCSearch(core.BaseSearchClass):
     transform_dictionary = {}

     @helper_functions.initializer
-    def __init__(self, label, outdir, theta_prior, tref, minStartTime,
-                 maxStartTime, sftfilepattern=None, detectors=None,
-                 nsteps=[100, 100], nwalkers=100, ntemps=1,
+    def __init__(self, theta_prior, tref, label, outdir='data',
+                 minStartTime=None, maxStartTime=None, sftfilepattern=None,
+                 detectors=None, nsteps=[100, 100], nwalkers=100, ntemps=1,
                  log10beta_min=-5, theta_initial=None,
                  rhohatmax=1000, binary=False, BSGL=False,
                  SSBprec=None, minCoverFreq=None, maxCoverFreq=None,
@@ -151,6 +154,10 @@ class MCMCSearch(core.BaseSearchClass):
             minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
             binary=self.binary, injectSources=self.injectSources,
             assumeSqrtSX=self.assumeSqrtSX, SSBprec=self.SSBprec)
+        if self.minStartTime is None:
+            self.minStartTime = self.search.minStartTime
+        if self.maxStartTime is None:
+            self.maxStartTime = self.search.maxStartTime

     def logp(self, theta_vals, theta_prior, theta_keys, search):
         H = [self._generic_lnprior(**theta_prior[key])(p) for p, key in
@@ -830,6 +837,9 @@ class MCMCSearch(core.BaseSearchClass):
             if key not in d:
                 d[key] = val

+        if 'add_pfs' in kwargs:
+            self.generate_loudest()
+
         if hasattr(self, 'search') is False:
             self._initiate_search_object()
         if self.binary is False:
@@ -1165,7 +1175,8 @@ class MCMCSearch(core.BaseSearchClass):
                  ntemps=self.ntemps, theta_keys=self.theta_keys,
                  theta_prior=self.theta_prior,
                  log10beta_min=self.log10beta_min,
-                 BSGL=self.BSGL)
+                 BSGL=self.BSGL, minStartTime=self.minStartTime,
+                 maxStartTime=self.maxStartTime)
         return d

     def _save_data(self, sampler, samples, lnprobs, lnlikes, all_lnlikelihood):
@@ -1212,6 +1223,11 @@ class MCMCSearch(core.BaseSearchClass):
             old_d.pop('lnlikes')
             old_d.pop('all_lnlikelihood')

+            for key in 'minStartTime', 'maxStartTime':
+                if new_d[key] is None:
+                    new_d[key] = old_d[key]
+                    setattr(self, key, new_d[key])
+
             mod_keys = []
             for key in new_d.keys():
                 if key in old_d:
@@ -1569,9 +1585,9 @@ class MCMCGlitchSearch(MCMCSearch):
         )

     @helper_functions.initializer
-    def __init__(self, label, outdir, theta_prior, tref, minStartTime,
-                 maxStartTime, sftfilepattern=None, detectors=None,
-                 nsteps=[100, 100], nwalkers=100, ntemps=1,
+    def __init__(self, theta_prior, tref, label, outdir='data',
+                 minStartTime=None, maxStartTime=None, sftfilepattern=None,
+                 detectors=None, nsteps=[100, 100], nwalkers=100, ntemps=1,
                  log10beta_min=-5, theta_initial=None,
                  rhohatmax=1000, binary=False, BSGL=False,
                  SSBprec=None, minCoverFreq=None, maxCoverFreq=None,
@@ -1610,6 +1626,10 @@ class MCMCGlitchSearch(MCMCSearch):
             minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq,
             detectors=self.detectors, BSGL=self.BSGL, nglitch=self.nglitch,
             theta0_idx=self.theta0_idx, injectSources=self.injectSources)
+        if self.minStartTime is None:
+            self.minStartTime = self.search.minStartTime
+        if self.maxStartTime is None:
+            self.maxStartTime = self.search.maxStartTime

     def logp(self, theta_vals, theta_prior, theta_keys, search):
         if self.nglitch > 1:
@@ -1778,9 +1798,9 @@ class MCMCSemiCoherentSearch(MCMCSearch):
     """

     @helper_functions.initializer
-    def __init__(self, label, outdir, theta_prior, tref, minStartTime,
-                 maxStartTime, sftfilepattern=None, detectors=None,
-                 nsteps=[100, 100], nwalkers=100, ntemps=1,
+    def __init__(self, theta_prior, tref, label, outdir='data',
+                 minStartTime=None, maxStartTime=None, sftfilepattern=None,
+                 detectors=None, nsteps=[100, 100], nwalkers=100, ntemps=1,
                  log10beta_min=-5, theta_initial=None,
                  rhohatmax=1000, binary=False, BSGL=False,
                  SSBprec=None, minCoverFreq=None, maxCoverFreq=None,
@@ -1830,6 +1850,10 @@ class MCMCSemiCoherentSearch(MCMCSearch):
             maxStartTime=self.maxStartTime, minCoverFreq=self.minCoverFreq,
             maxCoverFreq=self.maxCoverFreq, detectors=self.detectors,
             injectSources=self.injectSources, assumeSqrtSX=self.assumeSqrtSX)
+        if self.minStartTime is None:
+            self.minStartTime = self.search.minStartTime
+        if self.maxStartTime is None:
+            self.maxStartTime = self.search.maxStartTime

     def logp(self, theta_vals, theta_prior, theta_keys, search):
         H = [self._generic_lnprior(**theta_prior[key])(p) for p, key in
@@ -2144,6 +2168,10 @@ class MCMCTransientSearch(MCMCSearch):
             minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
             BSGL=self.BSGL, binary=self.binary,
             injectSources=self.injectSources)
+        if self.minStartTime is None:
+            self.minStartTime = self.search.minStartTime
+        if self.maxStartTime is None:
+            self.maxStartTime = self.search.maxStartTime

     def logl(self, theta, search):
         for j, theta_i in enumerate(self.theta_idxs):