Commit 4be56f25 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Improvements/fixes to the convergence code

- Fixes missing sqrt in the convergence test
- fixes error in periord
- prevents plotting between burn and prod
- allow user to specify upper limit
- refactor code
parent 012cf406
......@@ -216,7 +216,8 @@ class MCMCSearch(BaseSearchClass):
def setup_convergence_testing(
self, convergence_period=10, convergence_length=10,
convergence_burnin_fraction=0.25, convergence_threshold_number=10,
convergence_threshold=1.2, convergence_prod_threshold=2):
convergence_threshold=1.2, convergence_prod_threshold=2,
convergence_plot_upper_lim=2):
"""
If called, convergence testing is used during the MCMC simulation
......@@ -243,6 +244,8 @@ class MCMCSearch(BaseSearchClass):
recomend a value of 1.2, 1.1 for strict convergence
convergence_prod_threshold: float
the threshold to test the production values with
convergence_plot_upper_lim: float
the upper limit to use in the diagnostic plot
"""
if convergence_length > convergence_period:
......@@ -257,6 +260,7 @@ class MCMCSearch(BaseSearchClass):
self.convergence_threshold_number = convergence_threshold_number
self.convergence_threshold = convergence_threshold
self.convergence_number = 0
self.convergence_plot_upper_lim = convergence_plot_upper_lim
def get_convergence_statistic(self, i, sampler):
s = sampler.chain[0, :, i-self.convergence_length+1:i+1, :]
......@@ -268,15 +272,15 @@ class MCMCSearch(BaseSearchClass):
B_over_n = between_std**2 / self.convergence_period
Vhat = ((self.convergence_period-1.)/self.convergence_period * W
+ B_over_n + B_over_n / float(self.nwalkers))
c = Vhat/W
c = np.sqrt(Vhat/W)
self.convergence_diagnostic.append(c)
self.convergence_diagnosticx.append(i - self.convergence_period/2)
self.convergence_diagnosticx.append(i - self.convergence_length/2)
return c
def convergence_test(self, i, sampler, nburn):
def burnin_convergence_test(self, i, sampler, nburn):
if i < self.convergence_burnin_fraction*nburn:
return False
if np.mod(i+1, self.convergence_period) == 0:
if np.mod(i+1, self.convergence_period) != 0:
return False
c = self.get_convergence_statistic(i, sampler)
if np.all(c < self.convergence_threshold):
......@@ -285,6 +289,12 @@ class MCMCSearch(BaseSearchClass):
self.convergence_number = 0
return self.convergence_number > self.convergence_threshold_number
def prod_convergence_test(self, i, sampler, nburn):
testA = i > nburn + self.convergence_length
testB = np.mod(i+1, self.convergence_period) == 0
if testA and testB:
self.get_convergence_statistic(i, sampler)
def check_production_convergence(self, k):
bools = np.any(
np.array(self.convergence_diagnostic)[k:, :]
......@@ -296,26 +306,23 @@ class MCMCSearch(BaseSearchClass):
def run_sampler(self, sampler, p0, nprod=0, nburn=0):
if hasattr(self, 'convergence_period'):
converged = False
logging.info('Running {} burn-in steps with convergence testing'
.format(nburn))
iterator = tqdm(sampler.sample(p0, iterations=nburn), total=nburn)
for i, output in enumerate(iterator):
if converged:
if self.burnin_convergence_test(i, sampler, nburn):
logging.info(
'Converged at {} before max number {} of steps reached'
.format(i, nburn))
self.convergence_idx = i
break
else:
converged = self.convergence_test(i, sampler, nburn)
iterator.close()
logging.info('Running {} production steps'.format(nprod))
j = nburn
k = len(self.convergence_diagnostic)
for result in tqdm(sampler.sample(output[0], iterations=nprod),
total=nprod):
self.get_convergence_statistic(j, sampler)
self.prod_convergence_test(j, sampler, nburn)
j += 1
self.check_production_convergence(k)
return sampler
......@@ -365,7 +372,7 @@ class MCMCSearch(BaseSearchClass):
**kwargs)
fig.tight_layout()
fig.savefig('{}/{}_init_{}_walkers.png'.format(
self.outdir, self.label, j), dpi=200)
self.outdir, self.label, j), dpi=400)
p0 = self.get_new_p0(sampler)
p0 = self.apply_corrections_to_p0(p0)
......@@ -696,9 +703,12 @@ class MCMCSearch(BaseSearchClass):
ax = axes[i].twinx()
c_x = np.array(self.convergence_diagnosticx)
c_y = np.array(self.convergence_diagnostic)
ax.plot(c_x, c_y[:, i], '-b')
break_idx = np.argmin(np.abs(c_x - burnin_idx))
ax.plot(c_x[:break_idx], c_y[:break_idx, i], '-b')
ax.plot(c_x[break_idx:], c_y[break_idx:, i], '-b')
ax.set_ylabel('PSRF')
ax.ticklabel_format(useOffset=False)
ax.set_ylim(1, 5)
ax.set_ylim(1, self.convergence_plot_upper_lim)
else:
axes[0].ticklabel_format(useOffset=False, axis='y')
cs = chain[:, :, temp].T
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment