diff --git a/pyfstat/core.py b/pyfstat/core.py index d2ff3c7043b1cc278aec8649fa52a5152a75b177..2388069f69dcda6e1720d9f9bdc22dbce1d48526 100755 --- a/pyfstat/core.py +++ b/pyfstat/core.py @@ -48,7 +48,8 @@ def get_dictionary_from_lines(lines): def predict_fstat(h0, cosi, psi, Alpha, Delta, Freq, sftfilepattern, - minStartTime, maxStartTime, IFO=None, assumeSqrtSX=None): + minStartTime, maxStartTime, IFO=None, assumeSqrtSX=None, + **kwargs): """ Wrapper to lalapps_PredictFstat """ c_l = [] c_l.append("lalapps_PredictFstat") @@ -572,9 +573,17 @@ class ComputeFstat(object): return times, pfs, pfs_sigma def plot_twoF_cumulative(self, label, outdir, ax=None, c='k', savefig=True, - title=None, add_pfs=False, N=15, **kwargs): + title=None, add_pfs=False, N=15, + injectSources=None, **kwargs): if ax is None: fig, ax = plt.subplots() + if injectSources: + pfs_input = dict( + h0=injectSources['h0'], cosi=injectSources['cosi'], + psi=injectSources['psi'], Alpha=injectSources['Alpha'], + Delta=injectSources['Delta'], Freq=injectSources['fkdot'][0]) + else: + pfs_input = None taus, twoFs = self.calculate_twoF_cumulative(**kwargs) ax.plot(taus/86400., twoFs, label='All detectors', color=c) @@ -591,7 +600,8 @@ class ComputeFstat(object): self.detector_names = detector_names if add_pfs: - times, pfs, pfs_sigma = self.calculate_pfs(label, outdir, N=N) + times, pfs, pfs_sigma = self.calculate_pfs( + label, outdir, N=N, pfs_input=pfs_input) ax.fill_between( (times-self.minStartTime)/86400., pfs-pfs_sigma, pfs+pfs_sigma, color=c, @@ -600,7 +610,7 @@ class ComputeFstat(object): if len(self.detector_names) > 1: for d in self.detector_names: times, pfs, pfs_sigma = self.calculate_pfs( - label, outdir, IFO=d.upper(), N=N) + label, outdir, IFO=d.upper(), N=N, pfs_input=pfs_input) ax.fill_between( (times-self.minStartTime)/86400., pfs-pfs_sigma, pfs+pfs_sigma, color=detector_colors[d.lower()], @@ -764,7 +774,7 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat): nglitch=0, sftfilepath=None, theta0_idx=0, BSGL=False, minCoverFreq=None, maxCoverFreq=None, assumeSqrtSX=None, detectors=None, earth_ephem=None, sun_ephem=None, - SSBprec=None): + SSBprec=None, injectSources=None): """ Parameters ---------- diff --git a/pyfstat/grid_based_searches.py b/pyfstat/grid_based_searches.py index 91ee79ee4e51561b0fcb64c683f917f172d3dea2..7bd7a0d7c34a8596db28693c2b2d98768f9c6af6 100644 --- a/pyfstat/grid_based_searches.py +++ b/pyfstat/grid_based_searches.py @@ -337,7 +337,7 @@ class FrequencySlidingWindow(GridSearch): maxStartTime=None, window_size=10*86400, window_delta=86400, BSGL=False, minCoverFreq=None, maxCoverFreq=None, earth_ephem=None, sun_ephem=None, detectors=None, - SSBprec=None): + SSBprec=None, injectSources=None): """ Parameters ---------- @@ -377,7 +377,8 @@ class FrequencySlidingWindow(GridSearch): earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem, detectors=self.detectors, transient=True, minStartTime=self.minStartTime, maxStartTime=self.maxStartTime, - BSGL=self.BSGL, SSBprec=self.SSBprec) + BSGL=self.BSGL, SSBprec=self.SSBprec, + injectSources=self.injectSources) self.search.get_det_stat = ( self.search.run_computefstatistic_single_point) diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py index e34dd836873a0035b295d2a44815e67e14640eb4..34c7d013b9dec4d513561c1fcd38a6534ecc0593 100644 --- a/pyfstat/mcmc_based_searches.py +++ b/pyfstat/mcmc_based_searches.py @@ -1440,7 +1440,7 @@ class MCMCGlitchSearch(MCMCSearch): theta_initial=None, scatter_val=1e-10, rhohatmax=1000, dtglitchmin=1*86400, theta0_idx=0, detectors=None, BSGL=False, minCoverFreq=None, maxCoverFreq=None, - earth_ephem=None, sun_ephem=None): + earth_ephem=None, sun_ephem=None, injectSources=None): """ Parameters ---------- @@ -1534,7 +1534,8 @@ class MCMCGlitchSearch(MCMCSearch): maxStartTime=self.maxStartTime, minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq, earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem, detectors=self.detectors, BSGL=self.BSGL, - nglitch=self.nglitch, theta0_idx=self.theta0_idx) + nglitch=self.nglitch, theta0_idx=self.theta0_idx, + injectSources=self.injectSources) def logp(self, theta_vals, theta_prior, theta_keys, search): if self.nglitch > 1: @@ -2103,7 +2104,8 @@ class MCMCTransientSearch(MCMCSearch): earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem, detectors=self.detectors, transient=True, minStartTime=self.minStartTime, maxStartTime=self.maxStartTime, - BSGL=self.BSGL, binary=self.binary) + BSGL=self.BSGL, binary=self.binary, + injectSources=self.injectSources) def logl(self, theta, search): for j, theta_i in enumerate(self.theta_idxs):