Commit 4be56f25 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Improvements/fixes to the convergence code

- Fixes missing sqrt in the convergence test
- fixes error in periord
- prevents plotting between burn and prod
- allow user to specify upper limit
- refactor code
parent 012cf406
...@@ -216,7 +216,8 @@ class MCMCSearch(BaseSearchClass): ...@@ -216,7 +216,8 @@ class MCMCSearch(BaseSearchClass):
def setup_convergence_testing( def setup_convergence_testing(
self, convergence_period=10, convergence_length=10, self, convergence_period=10, convergence_length=10,
convergence_burnin_fraction=0.25, convergence_threshold_number=10, convergence_burnin_fraction=0.25, convergence_threshold_number=10,
convergence_threshold=1.2, convergence_prod_threshold=2): convergence_threshold=1.2, convergence_prod_threshold=2,
convergence_plot_upper_lim=2):
""" """
If called, convergence testing is used during the MCMC simulation If called, convergence testing is used during the MCMC simulation
...@@ -243,6 +244,8 @@ class MCMCSearch(BaseSearchClass): ...@@ -243,6 +244,8 @@ class MCMCSearch(BaseSearchClass):
recomend a value of 1.2, 1.1 for strict convergence recomend a value of 1.2, 1.1 for strict convergence
convergence_prod_threshold: float convergence_prod_threshold: float
the threshold to test the production values with the threshold to test the production values with
convergence_plot_upper_lim: float
the upper limit to use in the diagnostic plot
""" """
if convergence_length > convergence_period: if convergence_length > convergence_period:
...@@ -257,6 +260,7 @@ class MCMCSearch(BaseSearchClass): ...@@ -257,6 +260,7 @@ class MCMCSearch(BaseSearchClass):
self.convergence_threshold_number = convergence_threshold_number self.convergence_threshold_number = convergence_threshold_number
self.convergence_threshold = convergence_threshold self.convergence_threshold = convergence_threshold
self.convergence_number = 0 self.convergence_number = 0
self.convergence_plot_upper_lim = convergence_plot_upper_lim
def get_convergence_statistic(self, i, sampler): def get_convergence_statistic(self, i, sampler):
s = sampler.chain[0, :, i-self.convergence_length+1:i+1, :] s = sampler.chain[0, :, i-self.convergence_length+1:i+1, :]
...@@ -268,15 +272,15 @@ class MCMCSearch(BaseSearchClass): ...@@ -268,15 +272,15 @@ class MCMCSearch(BaseSearchClass):
B_over_n = between_std**2 / self.convergence_period B_over_n = between_std**2 / self.convergence_period
Vhat = ((self.convergence_period-1.)/self.convergence_period * W Vhat = ((self.convergence_period-1.)/self.convergence_period * W
+ B_over_n + B_over_n / float(self.nwalkers)) + B_over_n + B_over_n / float(self.nwalkers))
c = Vhat/W c = np.sqrt(Vhat/W)
self.convergence_diagnostic.append(c) self.convergence_diagnostic.append(c)
self.convergence_diagnosticx.append(i - self.convergence_period/2) self.convergence_diagnosticx.append(i - self.convergence_length/2)
return c return c
def convergence_test(self, i, sampler, nburn): def burnin_convergence_test(self, i, sampler, nburn):
if i < self.convergence_burnin_fraction*nburn: if i < self.convergence_burnin_fraction*nburn:
return False return False
if np.mod(i+1, self.convergence_period) == 0: if np.mod(i+1, self.convergence_period) != 0:
return False return False
c = self.get_convergence_statistic(i, sampler) c = self.get_convergence_statistic(i, sampler)
if np.all(c < self.convergence_threshold): if np.all(c < self.convergence_threshold):
...@@ -285,6 +289,12 @@ class MCMCSearch(BaseSearchClass): ...@@ -285,6 +289,12 @@ class MCMCSearch(BaseSearchClass):
self.convergence_number = 0 self.convergence_number = 0
return self.convergence_number > self.convergence_threshold_number return self.convergence_number > self.convergence_threshold_number
def prod_convergence_test(self, i, sampler, nburn):
testA = i > nburn + self.convergence_length
testB = np.mod(i+1, self.convergence_period) == 0
if testA and testB:
self.get_convergence_statistic(i, sampler)
def check_production_convergence(self, k): def check_production_convergence(self, k):
bools = np.any( bools = np.any(
np.array(self.convergence_diagnostic)[k:, :] np.array(self.convergence_diagnostic)[k:, :]
...@@ -296,26 +306,23 @@ class MCMCSearch(BaseSearchClass): ...@@ -296,26 +306,23 @@ class MCMCSearch(BaseSearchClass):
def run_sampler(self, sampler, p0, nprod=0, nburn=0): def run_sampler(self, sampler, p0, nprod=0, nburn=0):
if hasattr(self, 'convergence_period'): if hasattr(self, 'convergence_period'):
converged = False
logging.info('Running {} burn-in steps with convergence testing' logging.info('Running {} burn-in steps with convergence testing'
.format(nburn)) .format(nburn))
iterator = tqdm(sampler.sample(p0, iterations=nburn), total=nburn) iterator = tqdm(sampler.sample(p0, iterations=nburn), total=nburn)
for i, output in enumerate(iterator): for i, output in enumerate(iterator):
if converged: if self.burnin_convergence_test(i, sampler, nburn):
logging.info( logging.info(
'Converged at {} before max number {} of steps reached' 'Converged at {} before max number {} of steps reached'
.format(i, nburn)) .format(i, nburn))
self.convergence_idx = i self.convergence_idx = i
break break
else:
converged = self.convergence_test(i, sampler, nburn)
iterator.close() iterator.close()
logging.info('Running {} production steps'.format(nprod)) logging.info('Running {} production steps'.format(nprod))
j = nburn j = nburn
k = len(self.convergence_diagnostic) k = len(self.convergence_diagnostic)
for result in tqdm(sampler.sample(output[0], iterations=nprod), for result in tqdm(sampler.sample(output[0], iterations=nprod),
total=nprod): total=nprod):
self.get_convergence_statistic(j, sampler) self.prod_convergence_test(j, sampler, nburn)
j += 1 j += 1
self.check_production_convergence(k) self.check_production_convergence(k)
return sampler return sampler
...@@ -365,7 +372,7 @@ class MCMCSearch(BaseSearchClass): ...@@ -365,7 +372,7 @@ class MCMCSearch(BaseSearchClass):
**kwargs) **kwargs)
fig.tight_layout() fig.tight_layout()
fig.savefig('{}/{}_init_{}_walkers.png'.format( fig.savefig('{}/{}_init_{}_walkers.png'.format(
self.outdir, self.label, j), dpi=200) self.outdir, self.label, j), dpi=400)
p0 = self.get_new_p0(sampler) p0 = self.get_new_p0(sampler)
p0 = self.apply_corrections_to_p0(p0) p0 = self.apply_corrections_to_p0(p0)
...@@ -696,9 +703,12 @@ class MCMCSearch(BaseSearchClass): ...@@ -696,9 +703,12 @@ class MCMCSearch(BaseSearchClass):
ax = axes[i].twinx() ax = axes[i].twinx()
c_x = np.array(self.convergence_diagnosticx) c_x = np.array(self.convergence_diagnosticx)
c_y = np.array(self.convergence_diagnostic) c_y = np.array(self.convergence_diagnostic)
ax.plot(c_x, c_y[:, i], '-b') break_idx = np.argmin(np.abs(c_x - burnin_idx))
ax.plot(c_x[:break_idx], c_y[:break_idx, i], '-b')
ax.plot(c_x[break_idx:], c_y[break_idx:, i], '-b')
ax.set_ylabel('PSRF')
ax.ticklabel_format(useOffset=False) ax.ticklabel_format(useOffset=False)
ax.set_ylim(1, 5) ax.set_ylim(1, self.convergence_plot_upper_lim)
else: else:
axes[0].ticklabel_format(useOffset=False, axis='y') axes[0].ticklabel_format(useOffset=False, axis='y')
cs = chain[:, :, temp].T cs = chain[:, :, temp].T
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment