diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py
index 1c1d269fbde14e027ea90b8d505577dc5938027f..cf548a97a3b5e61465c829eb49f8e050ffce98ca 100644
--- a/pyfstat/mcmc_based_searches.py
+++ b/pyfstat/mcmc_based_searches.py
@@ -216,7 +216,8 @@ class MCMCSearch(BaseSearchClass):
     def setup_convergence_testing(
             self, convergence_period=10, convergence_length=10,
             convergence_burnin_fraction=0.25, convergence_threshold_number=10,
-            convergence_threshold=1.2, convergence_prod_threshold=2):
+            convergence_threshold=1.2, convergence_prod_threshold=2,
+            convergence_plot_upper_lim=2):
         """
         If called, convergence testing is used during the MCMC simulation
 
@@ -243,6 +244,8 @@ class MCMCSearch(BaseSearchClass):
             recomend a value of 1.2, 1.1 for strict convergence
         convergence_prod_threshold: float
             the threshold to test the production values with
+        convergence_plot_upper_lim: float
+            the upper limit to use in the diagnostic plot
         """
 
         if convergence_length > convergence_period:
@@ -257,6 +260,7 @@ class MCMCSearch(BaseSearchClass):
         self.convergence_threshold_number = convergence_threshold_number
         self.convergence_threshold = convergence_threshold
         self.convergence_number = 0
+        self.convergence_plot_upper_lim = convergence_plot_upper_lim
 
     def get_convergence_statistic(self, i, sampler):
         s = sampler.chain[0, :, i-self.convergence_length+1:i+1, :]
@@ -268,15 +272,15 @@ class MCMCSearch(BaseSearchClass):
         B_over_n = between_std**2 / self.convergence_period
         Vhat = ((self.convergence_period-1.)/self.convergence_period * W
                 + B_over_n + B_over_n / float(self.nwalkers))
-        c = Vhat/W
+        c = np.sqrt(Vhat/W)
         self.convergence_diagnostic.append(c)
-        self.convergence_diagnosticx.append(i - self.convergence_period/2)
+        self.convergence_diagnosticx.append(i - self.convergence_length/2)
         return c
 
-    def convergence_test(self, i, sampler, nburn):
+    def burnin_convergence_test(self, i, sampler, nburn):
         if i < self.convergence_burnin_fraction*nburn:
             return False
-        if np.mod(i+1, self.convergence_period) == 0:
+        if np.mod(i+1, self.convergence_period) != 0:
             return False
         c = self.get_convergence_statistic(i, sampler)
         if np.all(c < self.convergence_threshold):
@@ -285,6 +289,12 @@ class MCMCSearch(BaseSearchClass):
             self.convergence_number = 0
         return self.convergence_number > self.convergence_threshold_number
 
+    def prod_convergence_test(self, i, sampler, nburn):
+        testA = i > nburn + self.convergence_length
+        testB = np.mod(i+1, self.convergence_period) == 0
+        if testA and testB:
+            self.get_convergence_statistic(i, sampler)
+
     def check_production_convergence(self, k):
         bools = np.any(
             np.array(self.convergence_diagnostic)[k:, :]
@@ -296,26 +306,23 @@ class MCMCSearch(BaseSearchClass):
 
     def run_sampler(self, sampler, p0, nprod=0, nburn=0):
         if hasattr(self, 'convergence_period'):
-            converged = False
             logging.info('Running {} burn-in steps with convergence testing'
                          .format(nburn))
             iterator = tqdm(sampler.sample(p0, iterations=nburn), total=nburn)
             for i, output in enumerate(iterator):
-                if converged:
+                if self.burnin_convergence_test(i, sampler, nburn):
                     logging.info(
                         'Converged at {} before max number {} of steps reached'
                         .format(i, nburn))
                     self.convergence_idx = i
                     break
-                else:
-                    converged = self.convergence_test(i, sampler, nburn)
             iterator.close()
             logging.info('Running {} production steps'.format(nprod))
             j = nburn
             k = len(self.convergence_diagnostic)
             for result in tqdm(sampler.sample(output[0], iterations=nprod),
                                total=nprod):
-                self.get_convergence_statistic(j, sampler)
+                self.prod_convergence_test(j, sampler, nburn)
                 j += 1
             self.check_production_convergence(k)
             return sampler
@@ -365,7 +372,7 @@ class MCMCSearch(BaseSearchClass):
                                               **kwargs)
                 fig.tight_layout()
                 fig.savefig('{}/{}_init_{}_walkers.png'.format(
-                    self.outdir, self.label, j), dpi=200)
+                    self.outdir, self.label, j), dpi=400)
 
             p0 = self.get_new_p0(sampler)
             p0 = self.apply_corrections_to_p0(p0)
@@ -696,9 +703,12 @@ class MCMCSearch(BaseSearchClass):
                         ax = axes[i].twinx()
                         c_x = np.array(self.convergence_diagnosticx)
                         c_y = np.array(self.convergence_diagnostic)
-                        ax.plot(c_x, c_y[:, i], '-b')
+                        break_idx = np.argmin(np.abs(c_x - burnin_idx))
+                        ax.plot(c_x[:break_idx], c_y[:break_idx, i], '-b')
+                        ax.plot(c_x[break_idx:], c_y[break_idx:, i], '-b')
+                        ax.set_ylabel('PSRF')
                         ax.ticklabel_format(useOffset=False)
-                        ax.set_ylim(1, 5)
+                        ax.set_ylim(1, self.convergence_plot_upper_lim)
             else:
                 axes[0].ticklabel_format(useOffset=False, axis='y')
                 cs = chain[:, :, temp].T