From 286b576c0843e6b0ad487eb05815ff6b348e3fca Mon Sep 17 00:00:00 2001 From: "gregory.ashton" <gregory.ashton@ligo.org> Date: Wed, 3 May 2017 12:14:08 +0200 Subject: [PATCH] Add early stopping and clean up PSRF calculation --- pyfstat/mcmc_based_searches.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py index 905dafb..bd76859 100644 --- a/pyfstat/mcmc_based_searches.py +++ b/pyfstat/mcmc_based_searches.py @@ -229,7 +229,7 @@ class MCMCSearch(core.BaseSearchClass): self, convergence_period=10, convergence_length=10, convergence_burnin_fraction=0.25, convergence_threshold_number=10, convergence_threshold=1.2, convergence_prod_threshold=2, - convergence_plot_upper_lim=2): + convergence_plot_upper_lim=2, convergence_early_stopping=True): """ If called, convergence testing is used during the MCMC simulation @@ -258,6 +258,8 @@ class MCMCSearch(core.BaseSearchClass): the threshold to test the production values with convergence_plot_upper_lim: float the upper limit to use in the diagnostic plot + convergence_early_stopping: bool + if true, stop the burnin early if convergence is reached """ if convergence_length > convergence_period: @@ -273,18 +275,18 @@ class MCMCSearch(core.BaseSearchClass): self.convergence_threshold = convergence_threshold self.convergence_number = 0 self.convergence_plot_upper_lim = convergence_plot_upper_lim + self.convergence_early_stopping = convergence_early_stopping def _get_convergence_statistic(self, i, sampler): s = sampler.chain[0, :, i-self.convergence_length+1:i+1, :] - within_std = np.mean(np.var(s, axis=1), axis=0) + N = float(self.convergence_length) + M = float(self.nwalkers) + W = np.mean(np.var(s, axis=1), axis=0) per_walker_mean = np.mean(s, axis=1) mean = np.mean(per_walker_mean, axis=0) - between_std = np.sqrt(np.mean((per_walker_mean-mean)**2, axis=0)) - W = within_std - B_over_n = between_std**2 / self.convergence_period - Vhat = ((self.convergence_period-1.)/self.convergence_period * W - + B_over_n + B_over_n / float(self.nwalkers)) - c = np.sqrt(Vhat/W) + B = N / (M-1.) * np.sum((per_walker_mean-mean)**2, axis=0) + Vhat = (N-1)/N * W + (M+1)/(M*N) * B + c = Vhat/W self.convergence_diagnostic.append(c) self.convergence_diagnosticx.append(i - self.convergence_length/2) return c @@ -299,7 +301,8 @@ class MCMCSearch(core.BaseSearchClass): self.convergence_number += 1 else: self.convergence_number = 0 - return self.convergence_number > self.convergence_threshold_number + if self.convergence_early_stopping: + return self.convergence_number > self.convergence_threshold_number def _prod_convergence_test(self, i, sampler, nburn): testA = i > nburn + self.convergence_length @@ -873,7 +876,7 @@ class MCMCSearch(core.BaseSearchClass): ax.plot(c_x[break_idx:], c_y[break_idx:, i], '-b') ax.set_ylabel('PSRF') ax.ticklabel_format(useOffset=False) - ax.set_ylim(1, self.convergence_plot_upper_lim) + ax.set_ylim(0.5, self.convergence_plot_upper_lim) else: axes[0].ticklabel_format(useOffset=False, axis='y') cs = chain[:, :, temp].T -- GitLab