diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py
index acb466b64bac4aef587a6104a1b30ce1adfbc43d..22ed2dcc18f14ac115383b16e5f697df203fc299 100644
--- a/pyfstat/mcmc_based_searches.py
+++ b/pyfstat/mcmc_based_searches.py
@@ -10,6 +10,7 @@ import numpy as np
 import matplotlib
 import matplotlib.pyplot as plt
 import emcee
+import pymc3  # used by the commented-out Geweke convergence diagnostic below
 import corner
 import dill as pickle
 
@@ -208,11 +209,69 @@ class MCMCSearch(BaseSearchClass):
 
         return p0
 
-    def run_sampler_with_progress_bar(self, sampler, ns, p0):
+    def OLD_run_sampler_with_progress_bar(self, sampler, ns, p0):
         for result in tqdm(sampler.sample(p0, iterations=ns), total=ns):
             pass
         return sampler
 
+    # Alternative convergence check, kept for reference but currently unused:
+    # per-walker Geweke z-scores on the most recent samples via pymc3.geweke,
+    # reduced to a median per parameter.
+    #def run_sampler(self, sampler, ns, p0):
+    #    convergence_period = 200
+    #    convergence_diagnostic = []
+    #    convergence_diagnosticx = []
+    #    for i, result in enumerate(tqdm(
+    #            sampler.sample(p0, iterations=ns), total=ns)):
+    #        if np.mod(i+1, convergence_period) == 0:
+    #            s = sampler.chain[0, :, i-convergence_period+1:i+1, :]
+    #            score_per_parameter = []
+    #            for j in range(self.ndim):
+    #                scores = []
+    #                for k in range(self.nwalkers):
+    #                    out = pymc3.geweke(
+    #                        s[k, :, j].reshape((convergence_period)),
+    #                        intervals=2, first=0.4, last=0.4)
+    #                    scores.append(out[0][1])
+    #                score_per_parameter.append(np.median(scores))
+    #            convergence_diagnostic.append(score_per_parameter)
+    #            convergence_diagnosticx.append(i - convergence_period/2)
+    #    self.convergence_diagnostic = np.array(np.abs(convergence_diagnostic))
+    #    self.convergence_diagnosticx = convergence_diagnosticx
+    #    return sampler
+
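+    # Alternative convergence check, kept for reference but currently unused:
+    # a Gelman-Rubin-style comparison of the within-chain variance W and the
+    # between-chain variance B over the last convergence_period steps.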
+    #def run_sampler(self, sampler, ns, p0):
+    #    convergence_period = 200
+    #    convergence_diagnostic = []
+    #    convergence_diagnosticx = []
+    #    for i, result in enumerate(tqdm(
+    #            sampler.sample(p0, iterations=ns), total=ns)):
+    #        if np.mod(i+1, convergence_period) == 0:
+    #            s = sampler.chain[0, :, i-convergence_period+1:i+1, :]
+    #            mean_per_chain = np.mean(s, axis=1)
+    #            std_per_chain = np.std(s, axis=1)
+    #            mean = np.mean(mean_per_chain, axis=0)
+    #            B = convergence_period * np.sum((mean_per_chain - mean)**2, axis=0) / (self.nwalkers - 1)
+    #            W = np.sum(std_per_chain**2, axis=0) / self.nwalkers
+    #            print(B, W)
+    #            convergence_diagnostic.append(W/B)
+    #            convergence_diagnosticx.append(i - convergence_period/2)
+    #    self.convergence_diagnostic = np.array(np.abs(convergence_diagnostic))
+    #    self.convergence_diagnosticx = convergence_diagnosticx
+    #    return sampler
+
+    def run_sampler(self, sampler, ns, p0):
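+        # Every convergence_period steps, standardise the zero-temperature
+        # chain over the most recent window and record the mean Z-score per
+        # parameter, which is later overlaid on the walker plots.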
+        convergence_period = 200
+        convergence_diagnostic = []
+        convergence_diagnosticx = []
+        for i, result in enumerate(tqdm(
+                sampler.sample(p0, iterations=ns), total=ns)):
+            if np.mod(i+1, convergence_period) == 0:
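+                # Zero-temperature chain over the last window,
+                # shape (nwalkers, convergence_period, ndim).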
+                s = sampler.chain[0, :, i-convergence_period+1:i+1, :]
+                Z = (s - np.mean(s, axis=(0, 1)))/np.std(s, axis=(0, 1))
+                convergence_diagnostic.append(np.mean(Z, axis=(0, 1)))
+                convergence_diagnosticx.append(i - convergence_period/2)
+        self.convergence_diagnostic = np.array(np.abs(convergence_diagnostic))
+        self.convergence_diagnosticx = convergence_diagnosticx
+        return sampler
+
     def run(self, proposal_scale_factor=2, create_plots=True, **kwargs):
 
         self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
@@ -241,7 +300,7 @@ class MCMCSearch(BaseSearchClass):
         for j, n in enumerate(self.nsteps[:-2]):
             logging.info('Running {}/{} initialisation with {} steps'.format(
                 j, ninit_steps, n))
-            sampler = self.run_sampler_with_progress_bar(sampler, n, p0)
+            sampler = self.run_sampler(sampler, n, p0)
             logging.info("Mean acceptance fraction: {}"
                          .format(np.mean(sampler.acceptance_fraction, axis=1)))
             if self.ntemps > 1:
@@ -267,7 +326,7 @@ class MCMCSearch(BaseSearchClass):
         nprod = self.nsteps[-1]
         logging.info('Running final burn and prod with {} steps'.format(
             nburn+nprod))
-        sampler = self.run_sampler_with_progress_bar(sampler, nburn+nprod, p0)
+        sampler = self.run_sampler(sampler, nburn+nprod, p0)
         logging.info("Mean acceptance fraction: {}"
                      .format(np.mean(sampler.acceptance_fraction, axis=1)))
         if self.ntemps > 1:
@@ -575,6 +634,10 @@ class MCMCSearch(BaseSearchClass):
                                 symbols[i]+'$-$'+symbols[i]+'$_0$',
                                 labelpad=labelpad)
 
+                    if hasattr(self, 'convergence_diagnostic'):
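+                        # Overlay the stored convergence diagnostic for this
+                        # parameter on a secondary y-axis.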
+                        ax = axes[i].twinx()
+                        ax.plot(self.convergence_diagnosticx,
+                                self.convergence_diagnostic[:, i], '-b')
             else:
                 axes[0].ticklabel_format(useOffset=False, axis='y')
                 cs = chain[:, :, temp].T
@@ -623,7 +686,7 @@ class MCMCSearch(BaseSearchClass):
                     axes[-1].set_xlim(minv-0.1*Range, maxv+0.1*Range)
 
                 xfmt = matplotlib.ticker.ScalarFormatter()
-                xfmt.set_powerlimits((-4, 4)) 
+                xfmt.set_powerlimits((-4, 4))
                 axes[-1].xaxis.set_major_formatter(xfmt)
 
             axes[-2].set_xlabel(r'$\textrm{Number of steps}$', labelpad=0.2)
@@ -1593,8 +1656,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
             logging.info(('Running {}/{} with {} steps and {} nsegs '
                           '(Tcoh={:1.2f} days)').format(
                 j+1, len(run_setup), (nburn, nprod), nseg, Tcoh))
-            sampler = self.run_sampler_with_progress_bar(
-                sampler, nburn+nprod, p0)
+            sampler = self.run_sampler(sampler, nburn+nprod, p0)
             logging.info("Mean acceptance fraction: {}"
                          .format(np.mean(sampler.acceptance_fraction, axis=1)))
             if self.ntemps > 1:
diff --git a/tests.py b/tests.py
index b6e8fa09ee7fc0e3f50cf6c4b140690536f23fcb..bcf7ec8ec4f51d264a16d79338f4c81008789e0d 100644
--- a/tests.py
+++ b/tests.py
@@ -192,7 +192,7 @@ class TestMCMCSearch(Test):
     label = "Test"
 
     def test_fully_coherent(self):
-        h0 = 1e-24
+        h0 = 1e-27
         sqrtSX = 1e-22
         F0 = 30
         F1 = -1e-10
@@ -203,27 +203,26 @@ class TestMCMCSearch(Test):
         Alpha = 5e-3
         Delta = 1.2
         tref = minStartTime
-        dtglitch = None
         delta_F0 = 0
         Writer = pyfstat.Writer(F0=F0, F1=F1, F2=F2, label=self.label,
                                 h0=h0, sqrtSX=sqrtSX,
                                 outdir=outdir, tstart=minStartTime,
                                 Alpha=Alpha, Delta=Delta, tref=tref,
-                                duration=duration, dtglitch=dtglitch,
+                                duration=duration,
                                 delta_F0=delta_F0, Band=4)
 
         Writer.make_data()
         predicted_FS = Writer.predict_fstat()
 
-        theta = {'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-9*F0)},
-                 'F1': {'type': 'norm', 'loc': F1, 'scale': np.abs(1e-9*F1)},
+        theta = {'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-8*F0)},
+                 'F1': {'type': 'norm', 'loc': F1, 'scale': np.abs(1e-3*F1)},
                  'F2': F2, 'Alpha': Alpha, 'Delta': Delta}
 
         search = pyfstat.MCMCSearch(
             label=self.label, outdir=outdir, theta_prior=theta, tref=tref,
             sftfilepath='{}/*{}*sft'.format(Writer.outdir, Writer.label),
             minStartTime=minStartTime, maxStartTime=maxStartTime,
-            nsteps=[100, 100], nwalkers=100, ntemps=1)
+            nsteps=[500, 100], nwalkers=100, ntemps=2)
         search.run()
         search.plot_corner(add_prior=True)
         _, FS = search.get_max_twoF()