diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py
index 22ed2dcc18f14ac115383b16e5f697df203fc299..d2759dbc1a68a9760bfa3a06107078ce502e1a61 100644
--- a/pyfstat/mcmc_based_searches.py
+++ b/pyfstat/mcmc_based_searches.py
@@ -214,63 +214,57 @@ class MCMCSearch(BaseSearchClass):
             pass
         return sampler
 
-    #def run_sampler(self, sampler, ns, p0):
-    #    convergence_period = 200
-    #    convergence_diagnostic = []
-    #    convergence_diagnosticx = []
-    #    for i, result in enumerate(tqdm(
-    #            sampler.sample(p0, iterations=ns), total=ns)):
-    #        if np.mod(i+1, convergence_period) == 0:
-    #            s = sampler.chain[0, :, i-convergence_period+1:i+1, :]
-    #            score_per_parameter = []
-    #            for j in range(self.ndim):
-    #                scores = []
-    #                for k in range(self.nwalkers):
-    #                    out = pymc3.geweke(
-    #                        s[k, :, j].reshape((convergence_period)),
-    #                        intervals=2, first=0.4, last=0.4)
-    #                    scores.append(out[0][1])
-    #                score_per_parameter.append(np.median(scores))
-    #            convergence_diagnostic.append(score_per_parameter)
-    #            convergence_diagnosticx.append(i - convergence_period/2)
-    #    self.convergence_diagnostic = np.array(np.abs(convergence_diagnostic))
-    #    self.convergence_diagnosticx = convergence_diagnosticx
-    #    return sampler
-
-    #def run_sampler(self, sampler, ns, p0):
-    #    convergence_period = 200
-    #    convergence_diagnostic = []
-    #    convergence_diagnosticx = []
-    #    for i, result in enumerate(tqdm(
-    #            sampler.sample(p0, iterations=ns), total=ns)):
-    #        if np.mod(i+1, convergence_period) == 0:
-    #            s = sampler.chain[0, :, i-convergence_period+1:i+1, :]
-    #            mean_per_chain = np.mean(s, axis=1)
-    #            std_per_chain = np.std(s, axis=1)
-    #            mean = np.mean(mean_per_chain, axis=0)
-    #            B = convergence_period * np.sum((mean_per_chain - mean)**2, axis=0) / (self.nwalkers - 1)
-    #            W = np.sum(std_per_chain**2, axis=0) / self.nwalkers
-    #            print B, W
-    #            convergence_diagnostic.append(W/B)
-    #            convergence_diagnosticx.append(i - convergence_period/2)
-    #    self.convergence_diagnostic = np.array(np.abs(convergence_diagnostic))
-    #    self.convergence_diagnosticx = convergence_diagnosticx
-    #    return sampler
-
-    def run_sampler(self, sampler, ns, p0):
-        convergence_period = 200
-        convergence_diagnostic = []
-        convergence_diagnosticx = []
-        for i, result in enumerate(tqdm(
-                sampler.sample(p0, iterations=ns), total=ns)):
-            if np.mod(i+1, convergence_period) == 0:
-                s = sampler.chain[0, :, i-convergence_period+1:i+1, :]
-                Z = (s - np.mean(s, axis=(0, 1)))/np.std(s, axis=(0, 1))
-                convergence_diagnostic.append(np.mean(Z, axis=(0, 1)))
-                convergence_diagnosticx.append(i - convergence_period/2)
-        self.convergence_diagnostic = np.array(np.abs(convergence_diagnostic))
-        self.convergence_diagnosticx = convergence_diagnosticx
-        return sampler
+    def setup_convergence_testing(
+            self, convergence_period=10, convergence_length=10,
+            convergence_burnin_fraction=0.25, convergence_threshold_number=5,
+            convergence_threshold=1.2):
+
+        if convergence_length > convergence_period:
+            raise ValueError('convergence_length must be <= convergence_period')
+        logging.info('Setting up convergence testing')
+        self.convergence_length = convergence_length
+        self.convergence_period = convergence_period
+        self.convergence_burnin_fraction = convergence_burnin_fraction
+        self.convergence_diagnostic = []
+        self.convergence_diagnosticx = []
+        self.convergence_threshold_number = convergence_threshold_number
+        self.convergence_threshold = convergence_threshold
+        self.convergence_number = 0
+
+    def convergence_test(self, i, sampler, nburn):
+        if i < self.convergence_burnin_fraction*nburn:
+            return False
+        if np.mod(i+1, self.convergence_period) != 0:
+            return False
+        s = sampler.chain[0, :, i-self.convergence_length+1:i+1, :]
+        within_std = np.mean(np.var(s, axis=1), axis=0)
+        per_walker_mean = np.mean(s, axis=1)
+        mean = np.mean(per_walker_mean, axis=0)
+        between_std = np.sqrt(np.mean((per_walker_mean-mean)**2, axis=0))
+        W = within_std
+        B_over_n = between_std**2 / self.convergence_length
+        Vhat = ((self.convergence_length-1.)/self.convergence_length * W
+                + B_over_n + B_over_n / float(self.nwalkers))
+        c = Vhat/W
+        self.convergence_diagnostic.append(c)
+        self.convergence_diagnosticx.append(i - self.convergence_period/2)
+        if np.all(c < self.convergence_threshold):
+            self.convergence_number += 1
+
+        return self.convergence_number > self.convergence_threshold_number
+
+    def run_sampler(self, sampler, p0, nprod=0, nburn=0):
+        if hasattr(self, 'convergence_period'):
+            for i, result in enumerate(tqdm(
+                    sampler.sample(p0, iterations=nburn+nprod),
+                    total=nburn+nprod)):
+                self.convergence_test(i, sampler, nburn)
+            return sampler
+        else:
+            for result in tqdm(sampler.sample(p0, iterations=nburn+nprod),
+                               total=nburn+nprod):
+                pass
+            return sampler
 
     def run(self, proposal_scale_factor=2, create_plots=True, **kwargs):
 
@@ -300,7 +294,7 @@ class MCMCSearch(BaseSearchClass):
         for j, n in enumerate(self.nsteps[:-2]):
             logging.info('Running {}/{} initialisation with {} steps'.format(
                 j, ninit_steps, n))
-            sampler = self.run_sampler(sampler, n, p0)
+            sampler = self.run_sampler(sampler, p0, nburn=n)
             logging.info("Mean acceptance fraction: {}"
                          .format(np.mean(sampler.acceptance_fraction, axis=1)))
             if self.ntemps > 1:
@@ -326,7 +320,7 @@ class MCMCSearch(BaseSearchClass):
         nprod = self.nsteps[-1]
         logging.info('Running final burn and prod with {} steps'.format(
             nburn+nprod))
-        sampler = self.run_sampler(sampler, nburn+nprod, p0)
+        sampler = self.run_sampler(sampler, p0, nburn=nburn, nprod=nprod)
         logging.info("Mean acceptance fraction: {}"
                      .format(np.mean(sampler.acceptance_fraction, axis=1)))
         if self.ntemps > 1:
@@ -636,8 +630,11 @@ class MCMCSearch(BaseSearchClass):
 
                     if hasattr(self, 'convergence_diagnostic'):
                         ax = axes[i].twinx()
-                        ax.plot(self.convergence_diagnosticx,
-                                self.convergence_diagnostic[:, i], '-b')
+                        c_x = np.array(self.convergence_diagnosticx)
+                        c_y = np.array(self.convergence_diagnostic)
+                        ax.plot(c_x, c_y[:, i], '-b')
+                        ax.ticklabel_format(useOffset=False)
+                        ax.set_ylim(1, 5)
             else:
                 axes[0].ticklabel_format(useOffset=False, axis='y')
                 cs = chain[:, :, temp].T
@@ -1656,7 +1653,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
             logging.info(('Running {}/{} with {} steps and {} nsegs '
                           '(Tcoh={:1.2f} days)').format(
                 j+1, len(run_setup), (nburn, nprod), nseg, Tcoh))
-            sampler = self.run_sampler(sampler, nburn+nprod, p0)
+            sampler = self.run_sampler(sampler, p0, nburn=nburn, nprod=nprod)
             logging.info("Mean acceptance fraction: {}"
                          .format(np.mean(sampler.acceptance_fraction, axis=1)))
             if self.ntemps > 1:
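
Note: the statistic computed in convergence_test above is a Gelman-Rubin
style potential scale reduction factor across walkers. For reference, a
minimal standalone sketch of the textbook form of that statistic (not part
of this patch; the function name and the synthetic data are illustrative
only, and the normalisation of the between-walker term differs slightly
from the patch above):

    import numpy as np

    def gelman_rubin(segment):
        # segment: chain samples of shape (nwalkers, n, ndim); values of
        # the returned ratio close to 1 indicate convergence.
        nwalkers, n, ndim = segment.shape
        W = np.mean(np.var(segment, axis=1), axis=0)  # mean within-walker variance
        per_walker_mean = np.mean(segment, axis=1)
        mean = np.mean(per_walker_mean, axis=0)
        B_over_n = np.mean((per_walker_mean - mean)**2, axis=0)  # variance of walker means
        Vhat = (n - 1.) / n * W + B_over_n + B_over_n / nwalkers
        return Vhat / W

    # Quick sanity check (illustrative only): stationary walkers give a
    # ratio near 1; walkers whose means disagree give a much larger ratio.
    rng = np.random.default_rng(0)
    stationary = rng.normal(size=(100, 50, 2))
    shifted = stationary + np.arange(100)[:, None, None]
    print(gelman_rubin(stationary))  # close to 1
    print(gelman_rubin(shifted))     # far above the 1.2 threshold
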
diff --git a/tests.py b/tests.py
index bcf7ec8ec4f51d264a16d79338f4c81008789e0d..30cc5e798e337fd1af2eff25ae956412b2951780 100644
--- a/tests.py
+++ b/tests.py
@@ -192,7 +192,7 @@ class TestMCMCSearch(Test):
     label = "Test"
 
     def test_fully_coherent(self):
-        h0 = 1e-27
+        h0 = 1e-24
         sqrtSX = 1e-22
         F0 = 30
         F1 = -1e-10
@@ -214,7 +214,7 @@ class TestMCMCSearch(Test):
         Writer.make_data()
         predicted_FS = Writer.predict_fstat()
 
-        theta = {'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-8*F0)},
+        theta = {'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-7*F0)},
                  'F1': {'type': 'norm', 'loc': F1, 'scale': np.abs(1e-3*F1)},
                  'F2': F2, 'Alpha': Alpha, 'Delta': Delta}
 
@@ -222,8 +222,9 @@ class TestMCMCSearch(Test):
             label=self.label, outdir=outdir, theta_prior=theta, tref=tref,
             sftfilepath='{}/*{}*sft'.format(Writer.outdir, Writer.label),
             minStartTime=minStartTime, maxStartTime=maxStartTime,
-            nsteps=[500, 100], nwalkers=100, ntemps=2)
-        search.run()
+            nsteps=[500, 100], nwalkers=100, ntemps=2, log10temperature_min=-1)
+        search.setup_convergence_testing()
+        search.run(create_plots=True)
         search.plot_corner(add_prior=True)
         _, FS = search.get_max_twoF()