diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py
index 67a21080a39ae1a3fe75a3bbeb4ce45ba9f57a31..1edffa4623110deae488653bd932dff45c75286e 100644
--- a/pyfstat/mcmc_based_searches.py
+++ b/pyfstat/mcmc_based_searches.py
@@ -41,11 +41,10 @@ class MCMCSearch(core.BaseSearchClass):
     detectors: str
         Two character reference to the detectors to use, specify None for no
         contraint and comma separate for multiple references.
-    nsteps: list (m,)
-        List specifying the number of steps to take, the last two entries
-        give the nburn and nprod of the 'production' run, all entries
-        before are for iterative initialisation steps (usually just one)
-        e.g. [1000, 1000, 500].
+    nsteps: list (2,)
+        Number of burn-in and production steps to take, [nburn, nprod]. See
+        `pyfstat.MCMCSearch.setup_initialisation()` for details on adding
+        initialisation steps.
     nwalkers, ntemps: int,
         The number of walkers and temperates to use in the parallel
         tempered PTSampler.
@@ -103,7 +102,7 @@ class MCMCSearch(core.BaseSearchClass):
                  maxStartTime, sftfilepattern=None, detectors=None,
                  nsteps=[100, 100], nwalkers=100, ntemps=1,
                  log10temperature_min=-5, theta_initial=None,
-                 scatter_val=1e-10, rhohatmax=1000, binary=False, BSGL=False,
+                 rhohatmax=1000, binary=False, BSGL=False,
                  SSBprec=None, minCoverFreq=None, maxCoverFreq=None,
                  injectSources=None, assumeSqrtSX=None):
 
@@ -137,7 +136,6 @@ class MCMCSearch(core.BaseSearchClass):
     def _log_input(self):
         logging.info('theta_prior = {}'.format(self.theta_prior))
         logging.info('nwalkers={}'.format(self.nwalkers))
-        logging.info('scatter_val = {}'.format(self.scatter_val))
         logging.info('nsteps = {}'.format(self.nsteps))
         logging.info('ntemps = {}'.format(self.ntemps))
         logging.info('log10temperature_min = {}'.format(
@@ -278,6 +276,30 @@ class MCMCSearch(core.BaseSearchClass):
         else:
             raise ValueError('test_type {} not understood'.format(test_type))
 
+    def setup_initialisation(self, nburn0, scatter_val=1e-10):
+        """ Add an initialisation step to the MCMC run
+
+        If called prior to `run()`, adds an intial step in which the MCMC
+        simulation is run for `nburn0` steps. After this, the MCMC simulation
+        continues in the usual manner (i.e. for nburn and nprod steps), but the
+        walkers are reset scattered around the maximum likelihood position
+        of the initialisation step.
+
+        Parameters
+        ----------
+        nburn0: int
+            Number of initialisation steps to take
+        scatter_val: float
+            Relative number to scatter walkers around the maximum likelihood
+            position after the initialisation step
+
+        """
+
+        logging.info('Setting up initialisation with nburn0={}, scatter_val={}'
+                     .format(nburn0, scatter_val))
+        self.nsteps = [nburn0] + self.nsteps
+        self.scatter_val = scatter_val
+
     def _test_autocorr_convergence(self, i, sampler, test=True, n_cut=5):
         try:
             acors = np.zeros((self.ntemps, self.ndim))
@@ -451,6 +473,7 @@ class MCMCSearch(core.BaseSearchClass):
         p0 = self._apply_corrections_to_p0(p0)
         self._check_initial_points(p0)
 
+        # Run initialisation steps if required
         ninit_steps = len(self.nsteps) - 2
         for j, n in enumerate(self.nsteps[:-2]):
             logging.info('Running {}/{} initialisation with {} steps'.format(
@@ -1139,7 +1162,7 @@ class MCMCSearch(core.BaseSearchClass):
     def _get_data_dictionary_to_save(self):
         d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers,
                  ntemps=self.ntemps, theta_keys=self.theta_keys,
-                 theta_prior=self.theta_prior, scatter_val=self.scatter_val,
+                 theta_prior=self.theta_prior,
                  log10temperature_min=self.log10temperature_min,
                  BSGL=self.BSGL)
         return d
@@ -1549,7 +1572,7 @@ class MCMCGlitchSearch(MCMCSearch):
                  maxStartTime, sftfilepattern=None, detectors=None,
                  nsteps=[100, 100], nwalkers=100, ntemps=1,
                  log10temperature_min=-5, theta_initial=None,
-                 scatter_val=1e-10, rhohatmax=1000, binary=False, BSGL=False,
+                 rhohatmax=1000, binary=False, BSGL=False,
                  SSBprec=None, minCoverFreq=None, maxCoverFreq=None,
                  injectSources=None, assumeSqrtSX=None,
                  dtglitchmin=1*86400, theta0_idx=0, nglitch=1):
@@ -1678,7 +1701,7 @@ class MCMCGlitchSearch(MCMCSearch):
     def _get_data_dictionary_to_save(self):
         d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers,
                  ntemps=self.ntemps, theta_keys=self.theta_keys,
-                 theta_prior=self.theta_prior, scatter_val=self.scatter_val,
+                 theta_prior=self.theta_prior,
                  log10temperature_min=self.log10temperature_min,
                  theta0_idx=self.theta0_idx, BSGL=self.BSGL)
         return d
@@ -1758,7 +1781,7 @@ class MCMCSemiCoherentSearch(MCMCSearch):
                  maxStartTime, sftfilepattern=None, detectors=None,
                  nsteps=[100, 100], nwalkers=100, ntemps=1,
                  log10temperature_min=-5, theta_initial=None,
-                 scatter_val=1e-10, rhohatmax=1000, binary=False, BSGL=False,
+                 rhohatmax=1000, binary=False, BSGL=False,
                  SSBprec=None, minCoverFreq=None, maxCoverFreq=None,
                  injectSources=None, assumeSqrtSX=None,
                  nsegs=None):
@@ -1792,7 +1815,7 @@ class MCMCSemiCoherentSearch(MCMCSearch):
     def _get_data_dictionary_to_save(self):
         d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers,
                  ntemps=self.ntemps, theta_keys=self.theta_keys,
-                 theta_prior=self.theta_prior, scatter_val=self.scatter_val,
+                 theta_prior=self.theta_prior,
                  log10temperature_min=self.log10temperature_min,
                  BSGL=self.BSGL, nsegs=self.nsegs)
         return d
@@ -1830,7 +1853,6 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
     def _get_data_dictionary_to_save(self):
         d = dict(nwalkers=self.nwalkers, ntemps=self.ntemps,
                  theta_keys=self.theta_keys, theta_prior=self.theta_prior,
-                 scatter_val=self.scatter_val,
                  log10temperature_min=self.log10temperature_min,
                  BSGL=self.BSGL, run_setup=self.run_setup)
         return d