diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py
index 06cdfc28e5556fe2033186e6422d0b59482ba98c..fefa8e7b841b1fdb295d40b41adca5fd0150716e 100644
--- a/pyfstat/mcmc_based_searches.py
+++ b/pyfstat/mcmc_based_searches.py
@@ -239,61 +239,63 @@ class MCMCSearch(core.BaseSearchClass):
             pass
         return sampler
 
-    def setup_convergence_testing(
-            self, convergence_period=10, convergence_length=10,
-            convergence_burnin_fraction=0.25, convergence_threshold_number=10,
-            convergence_threshold=1.2, convergence_prod_threshold=2,
-            convergence_plot_upper_lim=2, convergence_early_stopping=True):
+    def setup_burnin_convergence_testing(
+            self, n=10, test_type='autocorr', windowed=False, **kwargs):
         """
         If called, convergence testing is used during the MCMC simulation
 
-        This uses the Gelmanr-Rubin statistic based on the ratio of between and
-        within walkers variance. The original statistic was developed for
-        multiple (independent) MCMC simulations, in this context we simply use
-        the walkers
-
         Parameters
         ----------
-        convergence_period: int
-            period (in number of steps) at which to test convergence
-        convergence_length: int
-            number of steps to use in testing convergence - this should be
-            large enough to measure the variance, but if it is too long
-            this will result in incorect early convergence tests
-        convergence_burnin_fraction: float [0, 1]
-            the fraction of the burn-in period after which to start testing
-        convergence_threshold_number: int
-            the number of consecutive times where the test passes after which
-            to break the burn-in and go to production
-        convergence_threshold: float
-            the threshold to use in diagnosing convergence. Gelman & Rubin
-            recomend a value of 1.2, 1.1 for strict convergence
-        convergence_prod_threshold: float
-            the threshold to test the production values with
-        convergence_plot_upper_lim: float
-            the upper limit to use in the diagnostic plot
-        convergence_early_stopping: bool
-            if true, stop the burnin early if convergence is reached
+        n: int
+            Number of steps after which to test convergence
+        test_type: str ['autocorr', 'GR']
+            If 'autocorr' use the exponential autocorrelation time (kwargs
+            passed to `get_autocorr_convergence`). If 'GR' use the Gelman-Rubin
+            statistic (kwargs passed to `get_GR_convergence`)
+        windowed: bool
+            If True, only calculate the convergence test in a window of length
+            `n`
         """
-
-        if convergence_length > convergence_period:
-            raise ValueError('convergence_length must be < convergence_period')
         logging.info('Setting up convergence testing')
-        self.convergence_length = convergence_length
-        self.convergence_period = convergence_period
-        self.convergence_burnin_fraction = convergence_burnin_fraction
-        self.convergence_prod_threshold = convergence_prod_threshold
+        self.convergence_n = n
+        self.convergence_windowed = windowed
+        self.convergence_test_type = test_type
+        self.convergence_kwargs = kwargs
         self.convergence_diagnostic = []
         self.convergence_diagnosticx = []
-        self.convergence_threshold_number = convergence_threshold_number
-        self.convergence_threshold = convergence_threshold
-        self.convergence_number = 0
-        self.convergence_plot_upper_lim = convergence_plot_upper_lim
-        self.convergence_early_stopping = convergence_early_stopping
-
-    def _get_convergence_statistic(self, i, sampler):
-        s = sampler.chain[0, :, i-self.convergence_length+1:i+1, :]
-        N = float(self.convergence_length)
+        if test_type in ['autocorr']:
+            self._get_convergence_test = self.test_autocorr_convergence
+        elif test_type in ['GR']:
+            self._get_convergence_test = self.test_GR_convergence
+        else:
+            raise ValueError('test_type {} not understood'.format(test_type))
+
+    def test_autocorr_convergence(self, i, sampler, test=True, n_cut=5):
+        try:
+            acors = np.zeros((self.ntemps, self.ndim))
+            for temp in range(self.ntemps):
+                if self.convergence_windowed:
+                    j = i-self.convergence_n
+                else:
+                    j = 0
+                x = np.mean(sampler.chain[temp, :, j:i, :], axis=0)
+                acors[temp, :] = emcee.autocorr.exponential_time(x)  # NOTE(review): mainline emcee exposes integrated_time, not exponential_time — verify against the pinned emcee version
+            c = np.max(acors, axis=0)
+        except emcee.autocorr.AutocorrError:
+            c = np.zeros(self.ndim) + np.nan
+
+        self.convergence_diagnosticx.append(i - self.convergence_n/2.)
+        self.convergence_diagnostic.append(list(c))
+
+        if test:
+            return i > n_cut * np.max(c)
+
+    def test_GR_convergence(self, i, sampler, test=True, R=1.1):
+        if self.convergence_windowed:
+            s = sampler.chain[0, :, i-self.convergence_n+1:i+1, :]
+        else:
+            s = sampler.chain[0, :, :i+1, :]
+        N = float(s.shape[1])
         M = float(self.nwalkers)
         W = np.mean(np.var(s, axis=1), axis=0)
         per_walker_mean = np.mean(s, axis=1)
@@ -302,58 +304,45 @@ class MCMCSearch(core.BaseSearchClass):
         Vhat = (N-1)/N * W + (M+1)/(M*N) * B
         c = np.sqrt(Vhat/W)
         self.convergence_diagnostic.append(c)
-        self.convergence_diagnosticx.append(i - self.convergence_length/2)
-        return c
+        self.convergence_diagnosticx.append(i - self.convergence_n/2.)
 
-    def _burnin_convergence_test(self, i, sampler, nburn):
-        if i < self.convergence_burnin_fraction*nburn:
-            return False
-        if np.mod(i+1, self.convergence_period) != 0:
+        if test and np.max(c) < R:
+            return True
+        else:
             return False
-        c = self._get_convergence_statistic(i, sampler)
-        if np.all(c < self.convergence_threshold):
-            self.convergence_number += 1
+
+    def _test_convergence(self, i, sampler, **kwargs):
+        if np.mod(i+1, self.convergence_n) == 0:
+            return self._get_convergence_test(i, sampler, **kwargs)
         else:
-            self.convergence_number = 0
-        if self.convergence_early_stopping:
-            return self.convergence_number > self.convergence_threshold_number
-
-    def _prod_convergence_test(self, i, sampler, nburn):
-        testA = i > nburn + self.convergence_length
-        testB = np.mod(i+1, self.convergence_period) == 0
-        if testA and testB:
-            self._get_convergence_statistic(i, sampler)
-
-    def _check_production_convergence(self, k):
-        bools = np.any(
-            np.array(self.convergence_diagnostic)[k:, :]
-            > self.convergence_prod_threshold, axis=1)
-        if np.any(bools):
-            logging.warning(
-                '{} convergence tests in the production run of {} failed'
-                .format(np.sum(bools), len(bools)))
+            return False
+
+    def _run_sampler_with_conv_test(self, sampler, p0, nprod=0, nburn=0):
+        logging.info('Running {} burn-in steps with convergence testing'
+                     .format(nburn))
+        iterator = tqdm(sampler.sample(p0, iterations=nburn), total=nburn)
+        for i, output in enumerate(iterator):
+            if self._test_convergence(i, sampler, test=True,
+                                      **self.convergence_kwargs):
+                logging.info(
+                    'Converged at {} before max number {} of steps reached'
+                    .format(i, nburn))
+                self.convergence_idx = i
+                break
+        iterator.close()
+        logging.info('Running {} production steps'.format(nprod))
+        j = nburn
+        prod_p0 = output[0] if nburn > 0 else p0
+        iterator = tqdm(sampler.sample(prod_p0, iterations=nprod), total=nprod)
+        for result in iterator:
+            self._test_convergence(j, sampler, test=False,
+                                   **self.convergence_kwargs)
+            j += 1
+        return sampler
 
     def _run_sampler(self, sampler, p0, nprod=0, nburn=0):
-        if hasattr(self, 'convergence_period'):
-            logging.info('Running {} burn-in steps with convergence testing'
-                         .format(nburn))
-            iterator = tqdm(sampler.sample(p0, iterations=nburn), total=nburn)
-            for i, output in enumerate(iterator):
-                if self._burnin_convergence_test(i, sampler, nburn):
-                    logging.info(
-                        'Converged at {} before max number {} of steps reached'
-                        .format(i, nburn))
-                    self.convergence_idx = i
-                    break
-            iterator.close()
-            logging.info('Running {} production steps'.format(nprod))
-            j = nburn
-            k = len(self.convergence_diagnostic)
-            for result in tqdm(sampler.sample(output[0], iterations=nprod),
-                               total=nprod):
-                self._prod_convergence_test(j, sampler, nburn)
-                j += 1
-            self._check_production_convergence(k)
+        if hasattr(self, 'convergence_n'):
+            self._run_sampler_with_conv_test(sampler, p0, nprod, nburn)
         else:
             for result in tqdm(sampler.sample(p0, iterations=nburn+nprod),
                                total=nburn+nprod):
@@ -956,9 +945,11 @@ class MCMCSearch(core.BaseSearchClass):
                                 zorder=-10)
                         ax.plot(c_x[break_idx:], c_y[break_idx:, i], '-C0',
                                 zorder=-10)
-                        ax.set_ylabel('PSRF')
+                        if self.convergence_test_type == 'autocorr':
+                            ax.set_ylabel(r'$\tau_\mathrm{exp}$')
+                        elif self.convergence_test_type == 'GR':
+                            ax.set_ylabel('PSRF')
                         ax.ticklabel_format(useOffset=False)
-                        ax.set_ylim(0.5, self.convergence_plot_upper_lim)
             else:
                 axes[0].ticklabel_format(useOffset=False, axis='y')
                 cs = chain[:, :, temp].T