Commit 286b576c authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Add early stopping and clean up PSRF calculation

parent b5ea8ff3
......@@ -229,7 +229,7 @@ class MCMCSearch(core.BaseSearchClass):
self, convergence_period=10, convergence_length=10,
convergence_burnin_fraction=0.25, convergence_threshold_number=10,
convergence_threshold=1.2, convergence_prod_threshold=2,
convergence_plot_upper_lim=2):
convergence_plot_upper_lim=2, convergence_early_stopping=True):
"""
If called, convergence testing is used during the MCMC simulation
......@@ -258,6 +258,8 @@ class MCMCSearch(core.BaseSearchClass):
the threshold to test the production values with
convergence_plot_upper_lim: float
the upper limit to use in the diagnostic plot
convergence_early_stopping: bool
if true, stop the burnin early if convergence is reached
"""
if convergence_length > convergence_period:
......@@ -273,18 +275,18 @@ class MCMCSearch(core.BaseSearchClass):
self.convergence_threshold = convergence_threshold
self.convergence_number = 0
self.convergence_plot_upper_lim = convergence_plot_upper_lim
self.convergence_early_stopping = convergence_early_stopping
def _get_convergence_statistic(self, i, sampler):
s = sampler.chain[0, :, i-self.convergence_length+1:i+1, :]
within_std = np.mean(np.var(s, axis=1), axis=0)
N = float(self.convergence_length)
M = float(self.nwalkers)
W = np.mean(np.var(s, axis=1), axis=0)
per_walker_mean = np.mean(s, axis=1)
mean = np.mean(per_walker_mean, axis=0)
between_std = np.sqrt(np.mean((per_walker_mean-mean)**2, axis=0))
W = within_std
B_over_n = between_std**2 / self.convergence_period
Vhat = ((self.convergence_period-1.)/self.convergence_period * W
+ B_over_n + B_over_n / float(self.nwalkers))
c = np.sqrt(Vhat/W)
B = N / (M-1.) * np.sum((per_walker_mean-mean)**2, axis=0)
Vhat = (N-1)/N * W + (M+1)/(M*N) * B
c = Vhat/W
self.convergence_diagnostic.append(c)
self.convergence_diagnosticx.append(i - self.convergence_length/2)
return c
......@@ -299,7 +301,8 @@ class MCMCSearch(core.BaseSearchClass):
self.convergence_number += 1
else:
self.convergence_number = 0
return self.convergence_number > self.convergence_threshold_number
if self.convergence_early_stopping:
return self.convergence_number > self.convergence_threshold_number
def _prod_convergence_test(self, i, sampler, nburn):
testA = i > nburn + self.convergence_length
......@@ -873,7 +876,7 @@ class MCMCSearch(core.BaseSearchClass):
ax.plot(c_x[break_idx:], c_y[break_idx:, i], '-b')
ax.set_ylabel('PSRF')
ax.ticklabel_format(useOffset=False)
ax.set_ylim(1, self.convergence_plot_upper_lim)
ax.set_ylim(0.5, self.convergence_plot_upper_lim)
else:
axes[0].ticklabel_format(useOffset=False, axis='y')
cs = chain[:, :, temp].T
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment