From 286b576c0843e6b0ad487eb05815ff6b348e3fca Mon Sep 17 00:00:00 2001
From: "gregory.ashton" <gregory.ashton@ligo.org>
Date: Wed, 3 May 2017 12:14:08 +0200
Subject: [PATCH] Add early stopping and clean up PSRF calculation

---
 pyfstat/mcmc_based_searches.py | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py
index 905dafb..bd76859 100644
--- a/pyfstat/mcmc_based_searches.py
+++ b/pyfstat/mcmc_based_searches.py
@@ -229,7 +229,7 @@ class MCMCSearch(core.BaseSearchClass):
             self, convergence_period=10, convergence_length=10,
             convergence_burnin_fraction=0.25, convergence_threshold_number=10,
             convergence_threshold=1.2, convergence_prod_threshold=2,
-            convergence_plot_upper_lim=2):
+            convergence_plot_upper_lim=2, convergence_early_stopping=True):
         """
         If called, convergence testing is used during the MCMC simulation
 
@@ -258,6 +258,8 @@ class MCMCSearch(core.BaseSearchClass):
             the threshold to test the production values with
         convergence_plot_upper_lim: float
             the upper limit to use in the diagnostic plot
+        convergence_early_stopping: bool
+            if True, stop the burn-in early once convergence is reached
         """
 
         if convergence_length > convergence_period:
@@ -273,18 +275,18 @@ class MCMCSearch(core.BaseSearchClass):
         self.convergence_threshold = convergence_threshold
         self.convergence_number = 0
         self.convergence_plot_upper_lim = convergence_plot_upper_lim
+        self.convergence_early_stopping = convergence_early_stopping
 
     def _get_convergence_statistic(self, i, sampler):
         s = sampler.chain[0, :, i-self.convergence_length+1:i+1, :]
-        within_std = np.mean(np.var(s, axis=1), axis=0)
+        N = float(self.convergence_length)
+        M = float(self.nwalkers)
+        W = np.mean(np.var(s, axis=1), axis=0)
         per_walker_mean = np.mean(s, axis=1)
         mean = np.mean(per_walker_mean, axis=0)
-        between_std = np.sqrt(np.mean((per_walker_mean-mean)**2, axis=0))
-        W = within_std
-        B_over_n = between_std**2 / self.convergence_period
-        Vhat = ((self.convergence_period-1.)/self.convergence_period * W
-                + B_over_n + B_over_n / float(self.nwalkers))
-        c = np.sqrt(Vhat/W)
+        B = N / (M-1.) * np.sum((per_walker_mean-mean)**2, axis=0)
+        Vhat = (N-1)/N * W + (M+1)/(M*N) * B
+        c = Vhat/W
         self.convergence_diagnostic.append(c)
         self.convergence_diagnosticx.append(i - self.convergence_length/2)
         return c
@@ -299,7 +301,8 @@ class MCMCSearch(core.BaseSearchClass):
             self.convergence_number += 1
         else:
             self.convergence_number = 0
-        return self.convergence_number > self.convergence_threshold_number
+        if self.convergence_early_stopping:
+            return self.convergence_number > self.convergence_threshold_number
 
     def _prod_convergence_test(self, i, sampler, nburn):
         testA = i > nburn + self.convergence_length
@@ -873,7 +876,7 @@ class MCMCSearch(core.BaseSearchClass):
                         ax.plot(c_x[break_idx:], c_y[break_idx:, i], '-b')
                         ax.set_ylabel('PSRF')
                         ax.ticklabel_format(useOffset=False)
-                        ax.set_ylim(1, self.convergence_plot_upper_lim)
+                        ax.set_ylim(0.5, self.convergence_plot_upper_lim)
             else:
                 axes[0].ticklabel_format(useOffset=False, axis='y')
                 cs = chain[:, :, temp].T
-- 
GitLab