diff --git a/pyfstat.py b/pyfstat.py index d32caac2a12623f71fe70c34032e39d4b7ebb77f..3ead11890333cf5b0ae101db7a074f3710a78b7f 100755 --- a/pyfstat.py +++ b/pyfstat.py @@ -341,9 +341,9 @@ class ComputeFstat(object): self.init_computefstatistic_single_point() - def init_computefstatistic_single_point(self): - """ Initilisation step of run_computefstatistic for a single point """ - + def get_SFTCatalog(self): + if hasattr(self, 'SFTCatalog'): + return logging.info('Initialising SFTCatalog') constraints = lalpulsar.SFTConstraints() if self.detector: @@ -376,6 +376,12 @@ class ComputeFstat(object): int(SFT_timestamps[-1]), subprocess.check_output('lalapps_tconvert {}'.format( int(SFT_timestamps[-1])), shell=True).rstrip('\n'))) + self.SFTCatalog = SFTCatalog + + def init_computefstatistic_single_point(self): + """ Initilisation step of run_computefstatistic for a single point """ + + self.get_SFTCatalog() logging.info('Initialising ephems') ephems = lalpulsar.InitBarycenter(self.earth_ephem, self.sun_ephem) @@ -412,24 +418,22 @@ class ComputeFstat(object): PP.Doppler.fkdot = np.array(self.injectSources['fkdot']) PP.Doppler.refTime = self.tref if 't0' not in self.injectSources: - #PP.Transient.t0 = int(self.minStartTime) - #PP.Transient.tau = int(self.maxStartTime - self.minStartTime) PP.Transient.type = lalpulsar.TRANSIENT_NONE FstatOAs.injectSources = PPV else: FstatOAs.injectSources = lalpulsar.FstatOptionalArgsDefaults.injectSources if self.minCoverFreq is None or self.maxCoverFreq is None: - fAs = [d.header.f0 for d in SFTCatalog.data] + fAs = [d.header.f0 for d in self.SFTCatalog.data] fBs = [d.header.f0 + (d.numBins-1)*d.header.deltaF - for d in SFTCatalog.data] + for d in self.SFTCatalog.data] self.minCoverFreq = np.min(fAs) + 0.5 self.maxCoverFreq = np.max(fBs) - 0.5 logging.info('Min/max cover freqs not provided, using ' '{} and {}, est. from SFTs'.format( self.minCoverFreq, self.maxCoverFreq)) - self.FstatInput = lalpulsar.CreateFstatInput(SFTCatalog, + self.FstatInput = lalpulsar.CreateFstatInput(self.SFTCatalog, self.minCoverFreq, self.maxCoverFreq, dFreq, @@ -449,7 +453,7 @@ class ComputeFstat(object): self.FstatResults = lalpulsar.FstatResults() if self.BSGL: - if len(names) < 2: + if len(self.names) < 2: raise ValueError("Can't use BSGL with single detector data") else: logging.info('Initialising BSGL') @@ -811,7 +815,7 @@ class MCMCSearch(BaseSearchClass): """ MCMC search using ComputeFstat""" @initializer def __init__(self, label, outdir, sftfilepath, theta_prior, tref, - minStartTime, maxStartTime, nsteps=[100, 100, 100], + minStartTime, maxStartTime, nsteps=[100, 100], nwalkers=100, ntemps=1, log10temperature_min=-5, theta_initial=None, scatter_val=1e-10, binary=False, BSGL=False, minCoverFreq=None, @@ -1323,14 +1327,14 @@ class MCMCSearch(BaseSearchClass): if fig is None and axes is None: fig = plt.figure(figsize=(8, 4*ndim)) ax = fig.add_subplot(ndim+1, 1, 1) - axes = [ax] + [fig.add_subplot(ndim+1, 1, i, sharex=ax) + axes = [ax] + [fig.add_subplot(ndim+1, 1, i) for i in range(2, ndim+1)] idxs = np.arange(chain.shape[1]) if ndim > 1: for i in range(ndim): axes[i].ticklabel_format(useOffset=False, axis='y') - if i < ndim: + if i < ndim-1: axes[i].set_xticklabels([]) cs = chain[:, :, i].T if burnin_idx: @@ -1749,16 +1753,17 @@ class MCMCGlitchSearch(MCMCSearch): """ MCMC search using the SemiCoherentGlitchSearch """ @initializer def __init__(self, label, outdir, sftfilepath, theta_prior, tref, - minStartTime, maxStartTime, nglitch=1, nsteps=[100, 100, 100], + minStartTime, maxStartTime, nglitch=1, nsteps=[100, 100], nwalkers=100, ntemps=1, log10temperature_min=-5, theta_initial=None, scatter_val=1e-10, dtglitchmin=1*86400, theta0_idx=0, detector=None, BSGL=False, minCoverFreq=None, maxCoverFreq=None, earth_ephem=None, sun_ephem=None): """ Parameters + ---------- label, outdir: str A label and directory to read/write data from/to -_ sftfilepath: str + sftfilepath: str File patern to match SFTs theta_prior: dict Dictionary of priors and fixed values for the search parameters. @@ -1768,7 +1773,7 @@ _ sftfilepath: str theta_initial: dict, array, (None) Either a dictionary of distribution about which to distribute the initial walkers about, an array (from which the walkers will be - scattered by scatter_val), or None in which case the prior is used. + scattered by scatter_val), or None in which case the prior is used. scatter_val, float or ndim array Size of scatter to use about the initialisation step, if given as an array it must be of length ndim and the order is given by @@ -2066,6 +2071,10 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch): BSGL=self.BSGL, run_setup=self.run_setup) return d + def update_search_object(self): + logging.info('Update search object') + self.search.init_computefstatistic_single_point() + def get_width_from_prior(self, prior, key): if prior[key]['type'] == 'unif': return prior[key]['upper'] - prior[key]['lower'] @@ -2230,7 +2239,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch): self.nsegs = nseg self.search.nsegs = nseg - self.inititate_search_object() + self.update_search_object() self.search.init_semicoherent_parameters() sampler = emcee.PTSampler( self.ntemps, self.nwalkers, self.ndim, self.logl, self.logp, @@ -2255,12 +2264,6 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch): fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols, fig=fig, axes=axes, burnin_idx=nburn, xoffset=nsteps_total, **kwargs) - #yvals = axes[0].get_ylim() - #axes[0].annotate( - # #r'$T_{{\rm coh}}^{{\rm (days)}}{{=}}{:1.1f}$'.format(Tcoh), - # r'{}'.format(j), - # xy=(nsteps_total, yvals[0]*(1+1e-2*(yvals[1]-yvals[0])/yvals[1])), - # fontsize=6) for ax in axes[:-1]: ax.axvline(nsteps_total, color='k', ls='--') nsteps_total += nburn+nprod