Skip to content
Snippets Groups Projects
Commit 5e1bfa80 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Various improvements

- Add option to just update the search object in the follow-up
- Abstract loading in SFT data and check it is not already loaded before
  creating SFTCatalog
- Fix bug in plot_walkers in which the number of steps is not displaying
- Removal of unused code
- Fixes to docs
parent 8e931a02
No related branches found
No related tags found
No related merge requests found
......@@ -341,9 +341,9 @@ class ComputeFstat(object):
self.init_computefstatistic_single_point()
def init_computefstatistic_single_point(self):
""" Initilisation step of run_computefstatistic for a single point """
def get_SFTCatalog(self):
if hasattr(self, 'SFTCatalog'):
return
logging.info('Initialising SFTCatalog')
constraints = lalpulsar.SFTConstraints()
if self.detector:
......@@ -376,6 +376,12 @@ class ComputeFstat(object):
int(SFT_timestamps[-1]),
subprocess.check_output('lalapps_tconvert {}'.format(
int(SFT_timestamps[-1])), shell=True).rstrip('\n')))
self.SFTCatalog = SFTCatalog
def init_computefstatistic_single_point(self):
""" Initilisation step of run_computefstatistic for a single point """
self.get_SFTCatalog()
logging.info('Initialising ephems')
ephems = lalpulsar.InitBarycenter(self.earth_ephem, self.sun_ephem)
......@@ -412,24 +418,22 @@ class ComputeFstat(object):
PP.Doppler.fkdot = np.array(self.injectSources['fkdot'])
PP.Doppler.refTime = self.tref
if 't0' not in self.injectSources:
#PP.Transient.t0 = int(self.minStartTime)
#PP.Transient.tau = int(self.maxStartTime - self.minStartTime)
PP.Transient.type = lalpulsar.TRANSIENT_NONE
FstatOAs.injectSources = PPV
else:
FstatOAs.injectSources = lalpulsar.FstatOptionalArgsDefaults.injectSources
if self.minCoverFreq is None or self.maxCoverFreq is None:
fAs = [d.header.f0 for d in SFTCatalog.data]
fAs = [d.header.f0 for d in self.SFTCatalog.data]
fBs = [d.header.f0 + (d.numBins-1)*d.header.deltaF
for d in SFTCatalog.data]
for d in self.SFTCatalog.data]
self.minCoverFreq = np.min(fAs) + 0.5
self.maxCoverFreq = np.max(fBs) - 0.5
logging.info('Min/max cover freqs not provided, using '
'{} and {}, est. from SFTs'.format(
self.minCoverFreq, self.maxCoverFreq))
self.FstatInput = lalpulsar.CreateFstatInput(SFTCatalog,
self.FstatInput = lalpulsar.CreateFstatInput(self.SFTCatalog,
self.minCoverFreq,
self.maxCoverFreq,
dFreq,
......@@ -449,7 +453,7 @@ class ComputeFstat(object):
self.FstatResults = lalpulsar.FstatResults()
if self.BSGL:
if len(names) < 2:
if len(self.names) < 2:
raise ValueError("Can't use BSGL with single detector data")
else:
logging.info('Initialising BSGL')
......@@ -811,7 +815,7 @@ class MCMCSearch(BaseSearchClass):
""" MCMC search using ComputeFstat"""
@initializer
def __init__(self, label, outdir, sftfilepath, theta_prior, tref,
minStartTime, maxStartTime, nsteps=[100, 100, 100],
minStartTime, maxStartTime, nsteps=[100, 100],
nwalkers=100, ntemps=1, log10temperature_min=-5,
theta_initial=None, scatter_val=1e-10,
binary=False, BSGL=False, minCoverFreq=None,
......@@ -1323,14 +1327,14 @@ class MCMCSearch(BaseSearchClass):
if fig is None and axes is None:
fig = plt.figure(figsize=(8, 4*ndim))
ax = fig.add_subplot(ndim+1, 1, 1)
axes = [ax] + [fig.add_subplot(ndim+1, 1, i, sharex=ax)
axes = [ax] + [fig.add_subplot(ndim+1, 1, i)
for i in range(2, ndim+1)]
idxs = np.arange(chain.shape[1])
if ndim > 1:
for i in range(ndim):
axes[i].ticklabel_format(useOffset=False, axis='y')
if i < ndim:
if i < ndim-1:
axes[i].set_xticklabels([])
cs = chain[:, :, i].T
if burnin_idx:
......@@ -1749,16 +1753,17 @@ class MCMCGlitchSearch(MCMCSearch):
""" MCMC search using the SemiCoherentGlitchSearch """
@initializer
def __init__(self, label, outdir, sftfilepath, theta_prior, tref,
minStartTime, maxStartTime, nglitch=1, nsteps=[100, 100, 100],
minStartTime, maxStartTime, nglitch=1, nsteps=[100, 100],
nwalkers=100, ntemps=1, log10temperature_min=-5,
theta_initial=None, scatter_val=1e-10, dtglitchmin=1*86400,
theta0_idx=0, detector=None, BSGL=False, minCoverFreq=None,
maxCoverFreq=None, earth_ephem=None, sun_ephem=None):
"""
Parameters
----------
label, outdir: str
A label and directory to read/write data from/to
_ sftfilepath: str
sftfilepath: str
File patern to match SFTs
theta_prior: dict
Dictionary of priors and fixed values for the search parameters.
......@@ -2066,6 +2071,10 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
BSGL=self.BSGL, run_setup=self.run_setup)
return d
def update_search_object(self):
logging.info('Update search object')
self.search.init_computefstatistic_single_point()
def get_width_from_prior(self, prior, key):
if prior[key]['type'] == 'unif':
return prior[key]['upper'] - prior[key]['lower']
......@@ -2230,7 +2239,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
self.nsegs = nseg
self.search.nsegs = nseg
self.inititate_search_object()
self.update_search_object()
self.search.init_semicoherent_parameters()
sampler = emcee.PTSampler(
self.ntemps, self.nwalkers, self.ndim, self.logl, self.logp,
......@@ -2255,12 +2264,6 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols,
fig=fig, axes=axes, burnin_idx=nburn,
xoffset=nsteps_total, **kwargs)
#yvals = axes[0].get_ylim()
#axes[0].annotate(
# #r'$T_{{\rm coh}}^{{\rm (days)}}{{=}}{:1.1f}$'.format(Tcoh),
# r'{}'.format(j),
# xy=(nsteps_total, yvals[0]*(1+1e-2*(yvals[1]-yvals[0])/yvals[1])),
# fontsize=6)
for ax in axes[:-1]:
ax.axvline(nsteps_total, color='k', ls='--')
nsteps_total += nburn+nprod
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment