Skip to content
Snippets Groups Projects
Commit 286b576c authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Add early stopping and clean up PSRF calculation

parent b5ea8ff3
Branches
No related tags found
No related merge requests found
...@@ -229,7 +229,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -229,7 +229,7 @@ class MCMCSearch(core.BaseSearchClass):
self, convergence_period=10, convergence_length=10, self, convergence_period=10, convergence_length=10,
convergence_burnin_fraction=0.25, convergence_threshold_number=10, convergence_burnin_fraction=0.25, convergence_threshold_number=10,
convergence_threshold=1.2, convergence_prod_threshold=2, convergence_threshold=1.2, convergence_prod_threshold=2,
convergence_plot_upper_lim=2): convergence_plot_upper_lim=2, convergence_early_stopping=True):
""" """
If called, convergence testing is used during the MCMC simulation If called, convergence testing is used during the MCMC simulation
...@@ -258,6 +258,8 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -258,6 +258,8 @@ class MCMCSearch(core.BaseSearchClass):
the threshold to test the production values with the threshold to test the production values with
convergence_plot_upper_lim: float convergence_plot_upper_lim: float
the upper limit to use in the diagnostic plot the upper limit to use in the diagnostic plot
convergence_early_stopping: bool
if true, stop the burnin early if convergence is reached
""" """
if convergence_length > convergence_period: if convergence_length > convergence_period:
...@@ -273,18 +275,18 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -273,18 +275,18 @@ class MCMCSearch(core.BaseSearchClass):
self.convergence_threshold = convergence_threshold self.convergence_threshold = convergence_threshold
self.convergence_number = 0 self.convergence_number = 0
self.convergence_plot_upper_lim = convergence_plot_upper_lim self.convergence_plot_upper_lim = convergence_plot_upper_lim
self.convergence_early_stopping = convergence_early_stopping
def _get_convergence_statistic(self, i, sampler): def _get_convergence_statistic(self, i, sampler):
s = sampler.chain[0, :, i-self.convergence_length+1:i+1, :] s = sampler.chain[0, :, i-self.convergence_length+1:i+1, :]
within_std = np.mean(np.var(s, axis=1), axis=0) N = float(self.convergence_length)
M = float(self.nwalkers)
W = np.mean(np.var(s, axis=1), axis=0)
per_walker_mean = np.mean(s, axis=1) per_walker_mean = np.mean(s, axis=1)
mean = np.mean(per_walker_mean, axis=0) mean = np.mean(per_walker_mean, axis=0)
between_std = np.sqrt(np.mean((per_walker_mean-mean)**2, axis=0)) B = N / (M-1.) * np.sum((per_walker_mean-mean)**2, axis=0)
W = within_std Vhat = (N-1)/N * W + (M+1)/(M*N) * B
B_over_n = between_std**2 / self.convergence_period c = Vhat/W
Vhat = ((self.convergence_period-1.)/self.convergence_period * W
+ B_over_n + B_over_n / float(self.nwalkers))
c = np.sqrt(Vhat/W)
self.convergence_diagnostic.append(c) self.convergence_diagnostic.append(c)
self.convergence_diagnosticx.append(i - self.convergence_length/2) self.convergence_diagnosticx.append(i - self.convergence_length/2)
return c return c
...@@ -299,6 +301,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -299,6 +301,7 @@ class MCMCSearch(core.BaseSearchClass):
self.convergence_number += 1 self.convergence_number += 1
else: else:
self.convergence_number = 0 self.convergence_number = 0
if self.convergence_early_stopping:
return self.convergence_number > self.convergence_threshold_number return self.convergence_number > self.convergence_threshold_number
def _prod_convergence_test(self, i, sampler, nburn): def _prod_convergence_test(self, i, sampler, nburn):
...@@ -873,7 +876,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -873,7 +876,7 @@ class MCMCSearch(core.BaseSearchClass):
ax.plot(c_x[break_idx:], c_y[break_idx:, i], '-b') ax.plot(c_x[break_idx:], c_y[break_idx:, i], '-b')
ax.set_ylabel('PSRF') ax.set_ylabel('PSRF')
ax.ticklabel_format(useOffset=False) ax.ticklabel_format(useOffset=False)
ax.set_ylim(1, self.convergence_plot_upper_lim) ax.set_ylim(0.5, self.convergence_plot_upper_lim)
else: else:
axes[0].ticklabel_format(useOffset=False, axis='y') axes[0].ticklabel_format(useOffset=False, axis='y')
cs = chain[:, :, temp].T cs = chain[:, :, temp].T
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment