diff --git a/pyfstat/core.py b/pyfstat/core.py
index d2ff3c7043b1cc278aec8649fa52a5152a75b177..2388069f69dcda6e1720d9f9bdc22dbce1d48526 100755
--- a/pyfstat/core.py
+++ b/pyfstat/core.py
@@ -48,7 +48,8 @@ def get_dictionary_from_lines(lines):
 
 
 def predict_fstat(h0, cosi, psi, Alpha, Delta, Freq, sftfilepattern,
-                  minStartTime, maxStartTime, IFO=None, assumeSqrtSX=None):
+                  minStartTime, maxStartTime, IFO=None, assumeSqrtSX=None,
+                  **kwargs):
     """ Wrapper to lalapps_PredictFstat """
     c_l = []
     c_l.append("lalapps_PredictFstat")
@@ -572,9 +573,17 @@ class ComputeFstat(object):
         return times, pfs, pfs_sigma
 
     def plot_twoF_cumulative(self, label, outdir, ax=None, c='k', savefig=True,
-                             title=None, add_pfs=False, N=15, **kwargs):
+                             title=None, add_pfs=False, N=15,
+                             injectSources=None, **kwargs):
         if ax is None:
             fig, ax = plt.subplots()
+        if injectSources:
+            pfs_input = dict(
+                h0=injectSources['h0'], cosi=injectSources['cosi'],
+                psi=injectSources['psi'], Alpha=injectSources['Alpha'],
+                Delta=injectSources['Delta'], Freq=injectSources['fkdot'][0])
+        else:
+            pfs_input = None
 
         taus, twoFs = self.calculate_twoF_cumulative(**kwargs)
         ax.plot(taus/86400., twoFs, label='All detectors', color=c)
@@ -591,7 +600,8 @@ class ComputeFstat(object):
             self.detector_names = detector_names
 
         if add_pfs:
-            times, pfs, pfs_sigma = self.calculate_pfs(label, outdir, N=N)
+            times, pfs, pfs_sigma = self.calculate_pfs(
+                label, outdir, N=N, pfs_input=pfs_input)
             ax.fill_between(
                 (times-self.minStartTime)/86400., pfs-pfs_sigma, pfs+pfs_sigma,
                 color=c,
@@ -600,7 +610,7 @@ class ComputeFstat(object):
             if len(self.detector_names) > 1:
                 for d in self.detector_names:
                     times, pfs, pfs_sigma = self.calculate_pfs(
-                        label, outdir, IFO=d.upper(), N=N)
+                        label, outdir, IFO=d.upper(), N=N, pfs_input=pfs_input)
                     ax.fill_between(
                         (times-self.minStartTime)/86400., pfs-pfs_sigma,
                         pfs+pfs_sigma, color=detector_colors[d.lower()],
@@ -764,7 +774,7 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat):
                  nglitch=0, sftfilepath=None, theta0_idx=0, BSGL=False,
                  minCoverFreq=None, maxCoverFreq=None, assumeSqrtSX=None,
                  detectors=None, earth_ephem=None, sun_ephem=None,
-                 SSBprec=None):
+                 SSBprec=None, injectSources=None):
         """
         Parameters
         ----------
diff --git a/pyfstat/grid_based_searches.py b/pyfstat/grid_based_searches.py
index 91ee79ee4e51561b0fcb64c683f917f172d3dea2..7bd7a0d7c34a8596db28693c2b2d98768f9c6af6 100644
--- a/pyfstat/grid_based_searches.py
+++ b/pyfstat/grid_based_searches.py
@@ -337,7 +337,7 @@ class FrequencySlidingWindow(GridSearch):
                  maxStartTime=None, window_size=10*86400, window_delta=86400,
                  BSGL=False, minCoverFreq=None, maxCoverFreq=None,
                  earth_ephem=None, sun_ephem=None, detectors=None,
-                 SSBprec=None):
+                 SSBprec=None, injectSources=None):
         """
         Parameters
         ----------
@@ -377,7 +377,8 @@ class FrequencySlidingWindow(GridSearch):
             earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem,
             detectors=self.detectors, transient=True,
             minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
-            BSGL=self.BSGL, SSBprec=self.SSBprec)
+            BSGL=self.BSGL, SSBprec=self.SSBprec,
+            injectSources=self.injectSources)
         self.search.get_det_stat = (
             self.search.run_computefstatistic_single_point)
 
diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py
index e34dd836873a0035b295d2a44815e67e14640eb4..34c7d013b9dec4d513561c1fcd38a6534ecc0593 100644
--- a/pyfstat/mcmc_based_searches.py
+++ b/pyfstat/mcmc_based_searches.py
@@ -1440,7 +1440,7 @@ class MCMCGlitchSearch(MCMCSearch):
                  theta_initial=None, scatter_val=1e-10, rhohatmax=1000,
                  dtglitchmin=1*86400, theta0_idx=0, detectors=None,
                  BSGL=False, minCoverFreq=None, maxCoverFreq=None,
-                 earth_ephem=None, sun_ephem=None):
+                 earth_ephem=None, sun_ephem=None, injectSources=None):
         """
         Parameters
         ----------
@@ -1534,7 +1534,8 @@ class MCMCGlitchSearch(MCMCSearch):
             maxStartTime=self.maxStartTime, minCoverFreq=self.minCoverFreq,
             maxCoverFreq=self.maxCoverFreq, earth_ephem=self.earth_ephem,
             sun_ephem=self.sun_ephem, detectors=self.detectors, BSGL=self.BSGL,
-            nglitch=self.nglitch, theta0_idx=self.theta0_idx)
+            nglitch=self.nglitch, theta0_idx=self.theta0_idx,
+            injectSources=self.injectSources)
 
     def logp(self, theta_vals, theta_prior, theta_keys, search):
         if self.nglitch > 1:
@@ -2103,7 +2104,8 @@ class MCMCTransientSearch(MCMCSearch):
             earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem,
             detectors=self.detectors, transient=True,
             minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
-            BSGL=self.BSGL, binary=self.binary)
+            BSGL=self.BSGL, binary=self.binary,
+            injectSources=self.injectSources)
 
     def logl(self, theta, search):
         for j, theta_i in enumerate(self.theta_idxs):