diff --git a/README.md b/README.md
index 60152e7effb4a58fb6bcfc23e51a9f73c055a276..6961c63adc5ecb6d8c86e5d0ad882e000a970d7d 100644
--- a/README.md
+++ b/README.md
@@ -16,6 +16,7 @@ to have run the [script to generate fake data](examples/make_fake_data.py).
 * [Fully-coherent MCMC search on data containing a single glitch](docs/fully_coherent_search_using_MCMC_on_glitching_data.md)
 * [Semi-coherent MCMC glitch-search on data containing a single glitch](docs/semi_coherent_glitch_search_using_MCMC_on_glitching_data.md)
 * [Semi-coherent MCMC glitch-search on data containing two glitches](docs/semi_coherent_glitch_search_with_two_glitches_using_MCMC_on_glitching_data.md)
+* [Semi-coherent Follow-Up MCMC search (dynamically changing the coherence time)](docs/follow_up.md)
 
 ## Installation
 
diff --git a/docs/follow_up.md b/docs/follow_up.md
new file mode 100644
index 0000000000000000000000000000000000000000..fbdfb421423ef16142106d1f3c85de9e6e69033c
--- /dev/null
+++ b/docs/follow_up.md
@@ -0,0 +1,85 @@
+# Semi-coherent Follow-Up MCMC search (dynamically changing the coherence time)
+
+Here, we will show the set-up for using the `MCMCFollowUpSearch` class. The
+basic idea is to start the MCMC chains searching on a likelihood with a short
+coherence time; once the chains converge to the solution, the coherence time
+is extended, effectively narrowing the peak, after which the chains again
+converge to this narrower peak. The advantages of such a method are:
+
+1. It dynamically shows the evolution of the walkers as the coherence time is
+increased.
+2. It is able to handle multiple peaks and hence can recover a multi-modal
+posterior.
+
+The plots here are produced by [follow_up.py](../examples/follow_up.py). We
+will run the search on the `basic` data generated in the
+[make_fake_data](make_fake_data.md) example. The basic script is here:
+
+```python
+import pyfstat
+
+F0 = 30.0
+F1 = -1e-10
+F2 = 0
+Alpha = 5e-3
+Delta = 6e-2
+tref = 362750407.0
+
+tstart = 1000000000
+duration = 100*86400
+tend = tstart + duration
+
+theta_prior = {'F0': {'type': 'unif', 'lower': F0*(1-1e-6), 'upper': F0*(1+1e-5)},
+               'F1': {'type': 'unif', 'lower': F1*(1+1e-2), 'upper': F1*(1-1e-2)},
+               'F2': F2,
+               'Alpha': Alpha,
+               'Delta': Delta
+               }
+
+ntemps = 1
+log10temperature_min = -1
+nwalkers = 100
+run_setup = [(1000, 50), (1000, 25), (1000, 1, False), 
+             ((500, 500), 1, True)]
+
+mcmc = pyfstat.MCMCFollowUpSearch(
+    label='follow_up', outdir='data',
+    sftfilepath='data/*basic*sft', theta_prior=theta_prior, tref=tref,
+    minStartTime=tstart, maxStartTime=tend, nwalkers=nwalkers,
+    ntemps=ntemps, log10temperature_min=log10temperature_min)
+mcmc.run(run_setup)
+mcmc.plot_corner(add_prior=True)
+mcmc.print_summary()
+```
+
+Note that we use the `MCMCFollowUpSearch` class. Rather than using the
+`nsteps` parameter to define how long the chains are run for, this class uses
+`run_setup`. This is a list (or array) with one element per stage, which
+determines the number of steps, how many of them should be considered
+burn-in, the number of segments to use, and whether to re-initialise the
+walkers. Each element of the list is a 3-tuple of the form
+`((nburn, nprod), nsegs, resetp0)`. However, each element can be given in a
+short form: either omitting the `nprod` as `(nburn, nsegs, resetp0)`, or
+omitting the `resetp0` as `((nburn, nprod), nsegs)`, or a combination of the
+two. For example, above we used
+
+```python
+run_setup = [(1000, 50), (1000, 25), (1000, 1, False), 
+             ((500, 500), 1, True)]
+```
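+
+As a rough sketch (not part of `pyfstat` itself), the hypothetical helper
+below expands these short forms into the full
+`((nburn, nprod), nsegs, resetp0)` form, assuming that an omitted `nprod`
+means zero production steps and an omitted `resetp0` means `False`:
+
+```python
+# Hypothetical helper, for illustration only: MCMCFollowUpSearch parses
+# run_setup internally and its defaults may differ from those assumed here.
+def expand_run_setup(run_setup, default_nprod=0, default_resetp0=False):
+    expanded = []
+    for entry in run_setup:
+        if len(entry) == 2:
+            steps, nsegs = entry
+            resetp0 = default_resetp0
+        else:
+            steps, nsegs, resetp0 = entry
+        if not isinstance(steps, (tuple, list)):
+            # Short form with only nburn given: assume no production steps
+            steps = (steps, default_nprod)
+        expanded.append((tuple(steps), nsegs, resetp0))
+    return expanded
+
+print(expand_run_setup([(1000, 50), (1000, 25), (1000, 1, False),
+                        ((500, 500), 1, True)]))
+# [((1000, 0), 50, False), ((1000, 0), 25, False),
+#  ((1000, 0), 1, False), ((500, 500), 1, True)]
+```
+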
+Here, the first two stages each use 1000 burn-in steps (which are discarded)
+while the number of segments is reduced from 50 to 25; the third stage uses
+1000 burn-in steps fully coherently (a single segment); and the final stage
+uses 500 burn-in and 500 production steps, resetting the walker positions at
+the beginning. The output is demonstrated here:
+
+![](img/follow_up_walkers.png)
+
+Some things to note:
+* The `resetp0` option is useful to clean up straggling walkers, but it will
+remove all multimodality, potentially losing information.
+* On the first axis the coherence time is displayed for each section.
+* In this example the signal is quite strong and in Gaussian noise.
+* The `twoF` distribution plotted at the bottom is taken only from the
+production run.
+* This plot is generated after each stage of the run, which can be useful to
+check that it is converging before continuing the simulation.
+
+
diff --git a/docs/img/follow_up_walkers.png b/docs/img/follow_up_walkers.png
index 70643f6882927a50fd7e527bec389ee51417c6f0..a70f881af4d62d7996a678137373c489db150473 100644
Binary files a/docs/img/follow_up_walkers.png and b/docs/img/follow_up_walkers.png differ
diff --git a/examples/follow_up.py b/examples/follow_up.py
index d76dc6d61a868e450aca2b4dc813487bca981a1a..cfd8575952dbc0c4c88c6cb2c85a021e2f5db4d5 100644
--- a/examples/follow_up.py
+++ b/examples/follow_up.py
@@ -21,8 +21,8 @@ theta_prior = {'F0': {'type': 'unif', 'lower': F0*(1-1e-6), 'upper': F0*(1+1e-5)
 ntemps = 1
 log10temperature_min = -1
 nwalkers = 100
-run_setup = [(500, 50), (500, 25), (100, 1, False), 
-             ((100, 100), 1, True)]
+run_setup = [(1000, 50), (1000, 25), (1000, 1, False), 
+             ((500, 500), 1, True)]
 
 mcmc = pyfstat.MCMCFollowUpSearch(
     label='follow_up', outdir='data',
diff --git a/pyfstat.py b/pyfstat.py
index 8aadc383e12cb3cfd8d7d05be20ce7f2bf06c3e9..3412bc906cd087d00d8be1b34d10a316af75f032 100755
--- a/pyfstat.py
+++ b/pyfstat.py
@@ -199,7 +199,7 @@ class ComputeFstat(object):
     def __init__(self, tref, sftfilepath=None, minStartTime=None,
                  maxStartTime=None, binary=False, transient=True, BSGL=False,
                  detector=None, minCoverFreq=None, maxCoverFreq=None,
-                 earth_ephem=None, sun_ephem=None,
+                 earth_ephem=None, sun_ephem=None, injectSources=None
                  ):
         """
         Parameters
@@ -270,7 +270,37 @@ class ComputeFstat(object):
         else:
             self.whatToCompute = lalpulsar.FSTATQ_2F
 
-        FstatOptionalArgs = lalpulsar.FstatOptionalArgsDefaults
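+        # Build the optional-arguments struct by copying each field from the
+        # lalpulsar defaults, so that individual entries (e.g. injectSources
+        # below) can be overridden.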
+        FstatOAs = lalpulsar.FstatOptionalArgs()
+        FstatOAs.randSeed = lalpulsar.FstatOptionalArgsDefaults.randSeed
+        FstatOAs.SSBprec = lalpulsar.FstatOptionalArgsDefaults.SSBprec
+        FstatOAs.Dterms = lalpulsar.FstatOptionalArgsDefaults.Dterms
+        FstatOAs.runningMedianWindow = lalpulsar.FstatOptionalArgsDefaults.runningMedianWindow
+        FstatOAs.FstatMethod = lalpulsar.FstatOptionalArgsDefaults.FstatMethod
+        FstatOAs.injectSqrtSX = lalpulsar.FstatOptionalArgsDefaults.injectSqrtSX
+        FstatOAs.assumeSqrtSX = lalpulsar.FstatOptionalArgsDefaults.assumeSqrtSX
+        FstatOAs.prevInput = lalpulsar.FstatOptionalArgsDefaults.prevInput
+        FstatOAs.collectTiming = lalpulsar.FstatOptionalArgsDefaults.collectTiming
+
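+        # If injectSources is given as a dict of pulsar parameters (h0, cosi,
+        # phi0, psi, Alpha, Delta, fkdot), build a single-element
+        # PulsarParamsVector so that lalpulsar injects this simulated signal
+        # when the F-statistic input is created; otherwise use the default
+        # (no injection).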
+        if type(self.injectSources) == dict:
+            logging.info('Injecting source with params: {}'.format(
+                self.injectSources))
+            PPV = lalpulsar.CreatePulsarParamsVector(1)
+            PP = PPV.data[0]
+            PP.Amp.h0 = self.injectSources['h0']
+            PP.Amp.cosi = self.injectSources['cosi']
+            PP.Amp.phi0 = self.injectSources['phi0']
+            PP.Amp.psi = self.injectSources['psi']
+            PP.Doppler.Alpha = self.injectSources['Alpha']
+            PP.Doppler.Delta = self.injectSources['Delta']
+            PP.Doppler.fkdot = np.array(self.injectSources['fkdot'])
+            PP.Doppler.refTime = self.tref
+            if 't0' not in self.injectSources:
+                #PP.Transient.t0 = int(self.minStartTime)
+                #PP.Transient.tau = int(self.maxStartTime - self.minStartTime)
+                PP.Transient.type = lalpulsar.TRANSIENT_NONE
+            FstatOAs.injectSources = PPV
+        else:
+            FstatOAs.injectSources = lalpulsar.FstatOptionalArgsDefaults.injectSources
 
         if self.minCoverFreq is None or self.maxCoverFreq is None:
             fA = SFTCatalog.data[0].header.f0
@@ -287,7 +317,7 @@ class ComputeFstat(object):
                                                      self.maxCoverFreq,
                                                      dFreq,
                                                      ephems,
-                                                     FstatOptionalArgs
+                                                     FstatOAs
                                                      )
 
         logging.info('Initialising PulsarDoplerParams')
@@ -447,7 +477,8 @@ class SemiCoherentSearch(BaseSearchClass, ComputeFstat):
     def __init__(self, label, outdir, tref, nsegs=None, sftfilepath=None,
                  binary=False, BSGL=False, minStartTime=None,
                  maxStartTime=None, minCoverFreq=None, maxCoverFreq=None,
-                 detector=None, earth_ephem=None, sun_ephem=None):
+                 detector=None, earth_ephem=None, sun_ephem=None,
+                 injectSources=None):
         """
         Parameters
         ----------
@@ -476,9 +507,8 @@ class SemiCoherentSearch(BaseSearchClass, ComputeFstat):
         logging.info(('Initialising semicoherent parameters from {} to {} in'
                       ' {} segments').format(
             self.minStartTime, self.maxStartTime, self.nsegs))
-        if self.nsegs == 1:
-            self.transient = False
-            self.whatToCompute = lalpulsar.FSTATQ_2F
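+        # Always use the transient machinery and compute per-detector
+        # F-statistic atoms, even for a fully-coherent (single segment)
+        # search, so the segment setup can be changed afterwards (e.g. by
+        # MCMCFollowUpSearch) without re-creating the search object.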
+        self.transient = True
+        self.whatToCompute = lalpulsar.FSTATQ_2F+lalpulsar.FSTATQ_ATOMS_PER_DET
         self.tboundaries = np.linspace(self.minStartTime, self.maxStartTime,
                                        self.nsegs+1)
 
@@ -658,7 +688,7 @@ class MCMCSearch(BaseSearchClass):
                  log10temperature_min=-5, theta_initial=None, scatter_val=1e-10,
                  binary=False, BSGL=False, minCoverFreq=None,
                  maxCoverFreq=None, detector=None, earth_ephem=None,
-                 sun_ephem=None):
+                 sun_ephem=None, injectSources=None):
         """
         Parameters
         label, outdir: str
@@ -746,7 +776,7 @@ class MCMCSearch(BaseSearchClass):
             earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem,
             detector=self.detector, BSGL=self.BSGL, transient=False,
             minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
-            binary=self.binary)
+            binary=self.binary, injectSources=self.injectSources)
 
     def logp(self, theta_vals, theta_prior, theta_keys, search):
         H = [self.generic_lnprior(**theta_prior[key])(p) for p, key in
@@ -1839,7 +1869,8 @@ class MCMCSemiCoherentSearch(MCMCSearch):
                  ntemps=1, log10temperature_min=-5, theta_initial=None,
                  scatter_val=1e-10, detector=None, BSGL=False,
                  minStartTime=None, maxStartTime=None, minCoverFreq=None,
-                 maxCoverFreq=None, earth_ephem=None, sun_ephem=None):
+                 maxCoverFreq=None, earth_ephem=None, sun_ephem=None,
+                 injectSources=None):
         """
 
         """
@@ -1875,7 +1906,8 @@ class MCMCSemiCoherentSearch(MCMCSearch):
             BSGL=self.BSGL, minStartTime=self.minStartTime,
             maxStartTime=self.maxStartTime, minCoverFreq=self.minCoverFreq,
             maxCoverFreq=self.maxCoverFreq, detector=self.detector,
-            earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem)
+            earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem,
+            injectSources=self.injectSources)
 
     def logp(self, theta_vals, theta_prior, theta_keys, search):
         H = [self.generic_lnprior(**theta_prior[key])(p) for p, key in
@@ -1938,6 +1970,8 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
         fig = None
         axes = None
         nsteps_total = 0
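+        # Create the search object once, starting with a single segment; each
+        # stage below then only updates the number of segments and
+        # re-initialises the semicoherent parameters.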
+        self.nsegs = 1
+        self.inititate_search_object()
         for j, ((nburn, nprod), nseg, reset_p0) in enumerate(run_setup):
             if j == 0:
                 p0 = self.generate_initial_p0()
@@ -1950,7 +1984,8 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
                 p0 = sampler.chain[:, :, -1, :]
 
             self.nsegs = nseg
-            self.inititate_search_object()
+            self.search.nsegs = nseg
+            self.search.init_semicoherent_parameters()
             sampler = emcee.PTSampler(
                 self.ntemps, self.nwalkers, self.ndim, self.logl, self.logp,
                 logpargs=(self.theta_prior, self.theta_keys, self.search),