diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py
index d2759dbc1a68a9760bfa3a06107078ce502e1a61..9d9b9ad02ae8748a40c7f4820156abd6df5eb509 100644
--- a/pyfstat/mcmc_based_searches.py
+++ b/pyfstat/mcmc_based_searches.py
@@ -216,8 +216,35 @@ class MCMCSearch(BaseSearchClass):
 
     def setup_convergence_testing(
             self, convergence_period=10, convergence_length=10,
-            convergence_burnin_fraction=0.25, convergence_threshold_number=5,
-            convergence_threshold=1.2):
+            convergence_burnin_fraction=0.25, convergence_threshold_number=10,
+            convergence_threshold=1.2, convergence_prod_threshold=2):
+        """
+        If called, convergence testing is used during the MCMC simulation
+
+        This uses the Gelman-Rubin statistic based on the ratio of between and
+        within walkers variance. The original statistic was developed for
+        multiple (independent) MCMC simulations, in this context we simply use
+        the walkers
+
+        Parameters
+        ----------
+        convergence_period: int
+            period (in number of steps) at which to test convergence
+        convergence_length: int
+            number of steps to use in testing convergence - this should be
+            large enough to measure the variance, but if it is too long
+            this will result in incorrect early convergence tests
+        convergence_burnin_fraction: float [0, 1]
+            the fraction of the burn-in period after which to start testing
+        convergence_threshold_number: int
+            the number of consecutive times that the test passes, after which
+            to break the burn-in and go to production
+        convergence_threshold: float
+            the threshold to use in diagnosing convergence. Gelman & Rubin
+            recommend a value of 1.2, or 1.1 for strict convergence
+        convergence_prod_threshold: float
+            the threshold to test the production values with
+        """
 
         if convergence_length > convergence_period:
             raise ValueError('convergence_length must be < convergence_period')
@@ -225,17 +252,14 @@ class MCMCSearch(BaseSearchClass):
         self.convergence_length = convergence_length
         self.convergence_period = convergence_period
         self.convergence_burnin_fraction = convergence_burnin_fraction
+        self.convergence_prod_threshold = convergence_prod_threshold
         self.convergence_diagnostic = []
         self.convergence_diagnosticx = []
         self.convergence_threshold_number = convergence_threshold_number
         self.convergence_threshold = convergence_threshold
         self.convergence_number = 0
 
-    def convergence_test(self, i, sampler, nburn):
-        if i < self.convergence_burnin_fraction*nburn:
-            return False
-        if np.mod(i+1, self.convergence_period) == 0:
-            return False
+    def get_convergence_statistic(self, i, sampler):
         s = sampler.chain[0, :, i-self.convergence_length+1:i+1, :]
         within_std = np.mean(np.var(s, axis=1), axis=0)
         per_walker_mean = np.mean(s, axis=1)
@@ -248,17 +272,53 @@ class MCMCSearch(BaseSearchClass):
         c = Vhat/W
         self.convergence_diagnostic.append(c)
         self.convergence_diagnosticx.append(i - self.convergence_period/2)
+        return c
+
+    def convergence_test(self, i, sampler, nburn):
+        if i < self.convergence_burnin_fraction*nburn:
+            return False
+        if np.mod(i+1, self.convergence_period) == 0:
+            return False
+        c = self.get_convergence_statistic(i, sampler)
         if np.all(c < self.convergence_threshold):
             self.convergence_number += 1
-
+        else:
+            self.convergence_number = 0
         return self.convergence_number > self.convergence_threshold_number
 
+    def check_production_convergence(self, k):
+        bools = np.any(
+            np.array(self.convergence_diagnostic)[k:, :]
+            > self.convergence_prod_threshold, axis=1)
+        if np.any(bools):
+            logging.warning(
+                '{} convergence tests in the production run of {} failed'
+                .format(np.sum(bools), len(bools)))
+
     def run_sampler(self, sampler, p0, nprod=0, nburn=0):
         if hasattr(self, 'convergence_period'):
-            for i, result in enumerate(tqdm(
-                    sampler.sample(p0, iterations=nburn+nprod),
-                    total=nburn+nprod)):
-                converged = self.convergence_test(i, sampler, nburn)
+            converged = False
+            logging.info('Running {} burn-in steps with convergence testing'
+                         .format(nburn))
+            iterator = tqdm(sampler.sample(p0, iterations=nburn), total=nburn)
+            for i, output in enumerate(iterator):
+                if converged:
+                    logging.info(
+                        'Converged at {} before max number {} of steps reached'
+                        .format(i, nburn))
+                    self.convergence_idx = i
+                    break
+                else:
+                    converged = self.convergence_test(i, sampler, nburn)
+            iterator.close()
+            logging.info('Running {} production steps'.format(nprod))
+            j = nburn
+            k = len(self.convergence_diagnostic)
+            for result in tqdm(sampler.sample(output[0], iterations=nprod),
+                               total=nprod):
+                self.get_convergence_statistic(j, sampler)
+                j += 1
+            self.check_production_convergence(k)
             return sampler
         else:
             for result in tqdm(sampler.sample(p0, iterations=nburn+nprod),
@@ -329,7 +389,7 @@ class MCMCSearch(BaseSearchClass):
 
         if create_plots:
             fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols,
-                                          burnin_idx=nburn, **kwargs)
+                                          nprod=nprod, **kwargs)
             fig.tight_layout()
             fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label),
                         dpi=200)
@@ -572,7 +632,7 @@ class MCMCSearch(BaseSearchClass):
             raise ValueError("dist_type {} unknown".format(dist_type))
 
     def plot_walkers(self, sampler, symbols=None, alpha=0.4, color="k", temp=0,
-                     lw=0.1, burnin_idx=None, add_det_stat_burnin=False,
+                     lw=0.1, nprod=None, add_det_stat_burnin=False,
                      fig=None, axes=None, xoffset=0, plot_det_stat=True,
                      context='classic', subtractions=None, labelpad=0.05):
         """ Plot all the chains from a sampler """
@@ -608,13 +668,18 @@ class MCMCSearch(BaseSearchClass):
                                for i in range(2, ndim+1)]
 
             idxs = np.arange(chain.shape[1])
+            burnin_idx = chain.shape[1] - nprod
+            if hasattr(self, 'convergence_idx'):
+                convergence_idx = self.convergence_idx
+            else:
+                convergence_idx = burnin_idx
             if ndim > 1:
                 for i in range(ndim):
                     axes[i].ticklabel_format(useOffset=False, axis='y')
                     cs = chain[:, :, i].T
-                    if burnin_idx:
-                        axes[i].plot(xoffset+idxs[:burnin_idx],
-                                     cs[:burnin_idx]-subtractions[i],
+                    if burnin_idx > 0:
+                        axes[i].plot(xoffset+idxs[:convergence_idx],
+                                     cs[:convergence_idx]-subtractions[i],
                                      color="r", alpha=alpha,
                                      lw=lw)
                     axes[i].plot(xoffset+idxs[burnin_idx:],
@@ -1665,7 +1730,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
             if create_plots:
                 fig, axes = self.plot_walkers(
                     sampler, symbols=self.theta_symbols, fig=fig, axes=axes,
-                    burnin_idx=nburn, xoffset=nsteps_total, **kwargs)
+                    nprod=nprod, xoffset=nsteps_total, **kwargs)
                 for ax in axes[:self.ndim]:
                     ax.axvline(nsteps_total, color='k', ls='--', lw=0.25)
 
diff --git a/tests.py b/tests.py
index 30cc5e798e337fd1af2eff25ae956412b2951780..7d639413f41cea9abdd45682940862c469fceb9b 100644
--- a/tests.py
+++ b/tests.py
@@ -214,15 +214,15 @@ class TestMCMCSearch(Test):
         Writer.make_data()
         predicted_FS = Writer.predict_fstat()
 
-        theta = {'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-7*F0)},
-                 'F1': {'type': 'norm', 'loc': F1, 'scale': np.abs(1e-3*F1)},
+        theta = {'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-9*F0)},
+                 'F1': {'type': 'norm', 'loc': F1, 'scale': np.abs(1e-9*F1)},
                  'F2': F2, 'Alpha': Alpha, 'Delta': Delta}
 
         search = pyfstat.MCMCSearch(
             label=self.label, outdir=outdir, theta_prior=theta, tref=tref,
             sftfilepath='{}/*{}*sft'.format(Writer.outdir, Writer.label),
             minStartTime=minStartTime, maxStartTime=maxStartTime,
-            nsteps=[500, 100], nwalkers=100, ntemps=2, log10temperature_min=-1)
+            nsteps=[100, 100], nwalkers=100, ntemps=2, log10temperature_min=-1)
         search.setup_convergence_testing()
         search.run(create_plots=True)
         search.plot_corner(add_prior=True)