Skip to content
Snippets Groups Projects
Commit afbb9815 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Reorganisation and cleanup of convergence testing

- Adds autocorrelation attempt (using PR 223 to emcee)
parent 50e20741
No related branches found
No related tags found
No related merge requests found
...@@ -239,61 +239,63 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -239,61 +239,63 @@ class MCMCSearch(core.BaseSearchClass):
pass pass
return sampler return sampler
def setup_convergence_testing( def setup_burnin_convergence_testing(
self, convergence_period=10, convergence_length=10, self, n=10, test_type='autocorr', windowed=False, **kwargs):
convergence_burnin_fraction=0.25, convergence_threshold_number=10,
convergence_threshold=1.2, convergence_prod_threshold=2,
convergence_plot_upper_lim=2, convergence_early_stopping=True):
""" """
If called, convergence testing is used during the MCMC simulation If called, convergence testing is used during the MCMC simulation
This uses the Gelmanr-Rubin statistic based on the ratio of between and
within walkers variance. The original statistic was developed for
multiple (independent) MCMC simulations, in this context we simply use
the walkers
Parameters Parameters
---------- ----------
convergence_period: int n: int
period (in number of steps) at which to test convergence Number of steps after which to test convergence
convergence_length: int test_type: str ['autocorr', 'GR']
number of steps to use in testing convergence - this should be If 'autocorr' use the exponential autocorrelation time (kwargs
large enough to measure the variance, but if it is too long passed to `get_autocorr_convergence`). If 'GR' use the Gelman-Rubin
this will result in incorect early convergence tests statistic (kwargs passed to `get_GR_convergence`)
convergence_burnin_fraction: float [0, 1] windowed: bool
the fraction of the burn-in period after which to start testing If True, only calculate the convergence test in a window of length
convergence_threshold_number: int `n`
the number of consecutive times where the test passes after which
to break the burn-in and go to production
convergence_threshold: float
the threshold to use in diagnosing convergence. Gelman & Rubin
recomend a value of 1.2, 1.1 for strict convergence
convergence_prod_threshold: float
the threshold to test the production values with
convergence_plot_upper_lim: float
the upper limit to use in the diagnostic plot
convergence_early_stopping: bool
if true, stop the burnin early if convergence is reached
""" """
if convergence_length > convergence_period:
raise ValueError('convergence_length must be < convergence_period')
logging.info('Setting up convergence testing') logging.info('Setting up convergence testing')
self.convergence_length = convergence_length self.convergence_n = n
self.convergence_period = convergence_period self.convergence_windowed = windowed
self.convergence_burnin_fraction = convergence_burnin_fraction self.convergence_test_type = test_type
self.convergence_prod_threshold = convergence_prod_threshold self.convergence_kwargs = kwargs
self.convergence_diagnostic = [] self.convergence_diagnostic = []
self.convergence_diagnosticx = [] self.convergence_diagnosticx = []
self.convergence_threshold_number = convergence_threshold_number if test_type in ['autocorr']:
self.convergence_threshold = convergence_threshold self._get_convergence_test = self.test_autocorr_convergence
self.convergence_number = 0 elif test_type in ['GR']:
self.convergence_plot_upper_lim = convergence_plot_upper_lim self._get_convergence_test= self.test_GR_convergence
self.convergence_early_stopping = convergence_early_stopping else:
raise ValueError('test_type {} not understood'.format(test_type))
def test_autocorr_convergence(self, i, sampler, test=True, n_cut=5):
    """Convergence test based on the exponential autocorrelation time.

    Parameters
    ----------
    i: int
        Current step index in the burn-in.
    sampler:
        Sampler whose `chain` is indexed as
        [temp, walker, step, dim] -- assumed from the slicing below;
        TODO confirm.
    test: bool
        If True, return a convergence verdict; if False, only record the
        diagnostic.
    n_cut: int
        Number of autocorrelation times the chain must exceed before
        being declared converged.

    Returns
    -------
    bool or None
        When `test` is True, whether `i` exceeds `n_cut` times the
        largest autocorrelation time; otherwise None.
    """
    # NOTE(review): emcee.autocorr.exponential_time comes from emcee
    # PR 223 (see commit message) and is not in released emcee --
    # confirm the pinned emcee version provides it.
    try:
        # Window start: last `convergence_n` steps, or the whole chain
        # so far (loop-invariant, so computed once).
        if self.convergence_windowed:
            j = i - self.convergence_n
        else:
            j = 0
        acors = np.zeros((self.ntemps, self.ndim))
        for temp in range(self.ntemps):
            # Average over walkers before estimating autocorrelation.
            x = np.mean(sampler.chain[temp, :, j:i, :], axis=0)
            acors[temp, :] = emcee.autocorr.exponential_time(x)
        # Keep the worst (largest) time over temperatures, per dimension.
        c = np.max(acors, axis=0)
    except emcee.autocorr.AutocorrError:
        # Chain too short for a reliable estimate: record NaNs instead.
        c = np.zeros(self.ndim) + np.nan
    self.convergence_diagnosticx.append(i - self.convergence_n / 2.)
    self.convergence_diagnostic.append(list(c))
    if test:
        return i > n_cut * np.max(c)
def test_GR_convergence(self, i, sampler, test=True, R=1.1):
if self.convergence_windowed:
s = sampler.chain[0, :, i-self.convergence_n+1:i+1, :]
else:
s = sampler.chain[0, :, :i+1, :]
N = float(self.convergence_n)
M = float(self.nwalkers) M = float(self.nwalkers)
W = np.mean(np.var(s, axis=1), axis=0) W = np.mean(np.var(s, axis=1), axis=0)
per_walker_mean = np.mean(s, axis=1) per_walker_mean = np.mean(s, axis=1)
...@@ -302,44 +304,26 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -302,44 +304,26 @@ class MCMCSearch(core.BaseSearchClass):
Vhat = (N-1)/N * W + (M+1)/(M*N) * B Vhat = (N-1)/N * W + (M+1)/(M*N) * B
c = np.sqrt(Vhat/W) c = np.sqrt(Vhat/W)
self.convergence_diagnostic.append(c) self.convergence_diagnostic.append(c)
self.convergence_diagnosticx.append(i - self.convergence_length/2) self.convergence_diagnosticx.append(i - self.convergence_n/2.)
return c
def _burnin_convergence_test(self, i, sampler, nburn): if test and np.max(c) < R:
if i < self.convergence_burnin_fraction*nburn: return True
else:
return False return False
if np.mod(i+1, self.convergence_period) != 0:
def _test_convergence(self, i, sampler, **kwargs):
if np.mod(i+1, self.convergence_n) == 0:
return self._get_convergence_test(i, sampler, **kwargs)
else:
return False return False
c = self._get_convergence_statistic(i, sampler)
if np.all(c < self.convergence_threshold):
self.convergence_number += 1
else:
self.convergence_number = 0
if self.convergence_early_stopping:
return self.convergence_number > self.convergence_threshold_number
def _prod_convergence_test(self, i, sampler, nburn):
testA = i > nburn + self.convergence_length
testB = np.mod(i+1, self.convergence_period) == 0
if testA and testB:
self._get_convergence_statistic(i, sampler)
def _check_production_convergence(self, k):
bools = np.any(
np.array(self.convergence_diagnostic)[k:, :]
> self.convergence_prod_threshold, axis=1)
if np.any(bools):
logging.warning(
'{} convergence tests in the production run of {} failed'
.format(np.sum(bools), len(bools)))
def _run_sampler_with_conv_test(self, sampler, p0, nprod=0, nburn=0):
    """Run the burn-in with per-step convergence testing, then production.

    During burn-in the configured test may end the loop early; during
    production the diagnostic is recorded but not acted upon.

    Parameters
    ----------
    sampler:
        The emcee sampler to advance.
    p0:
        Initial walker positions.
    nprod, nburn: int
        Maximum number of production and burn-in steps.

    Returns
    -------
    sampler
        The sampler, after nburn (at most) + nprod steps.
    """
    logging.info('Running {} burn-in steps with convergence testing'
                 .format(nburn))
    iterator = tqdm(sampler.sample(p0, iterations=nburn), total=nburn)
    for i, output in enumerate(iterator):
        if self._test_convergence(i, sampler, test=True,
                                  **self.convergence_kwargs):
            logging.info(
                'Converged at {} before max number {} of steps reached'
                .format(i, nburn))
            # NOTE(review): some lines here were elided from the diff
            # view; presumably the burn-in loop terminates early on
            # convergence -- TODO confirm against the full file.
            break
    iterator.close()
    logging.info('Running {} production steps'.format(nprod))
    j = nburn
    iterator = tqdm(sampler.sample(output[0], iterations=nprod),
                    total=nprod)
    for result in iterator:
        # Record the diagnostic during production without acting on it.
        self._test_convergence(j, sampler, test=False,
                               **self.convergence_kwargs)
        j += 1
    return sampler
def _run_sampler(self, sampler, p0, nprod=0, nburn=0):
    """Advance the sampler, with or without convergence testing.

    `convergence_n` only exists if `setup_burnin_convergence_testing`
    was called, so its presence selects the testing code path.
    """
    if hasattr(self, 'convergence_n'):
        self._run_sampler_with_conv_test(sampler, p0, nprod, nburn)
    else:
        for result in tqdm(sampler.sample(p0, iterations=nburn+nprod),
                           total=nburn+nprod):
            # NOTE(review): the loop body and any trailing statements are
            # elided from the diff view -- TODO confirm against the full
            # file.
            pass
    return sampler
...@@ -956,9 +945,11 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -956,9 +945,11 @@ class MCMCSearch(core.BaseSearchClass):
zorder=-10) zorder=-10)
ax.plot(c_x[break_idx:], c_y[break_idx:, i], '-C0', ax.plot(c_x[break_idx:], c_y[break_idx:, i], '-C0',
zorder=-10) zorder=-10)
if self.convergence_test_type == 'autocorr':
ax.set_ylabel(r'$\tau_\mathrm{exp}$')
elif self.convergence_test_type == 'GR':
ax.set_ylabel('PSRF') ax.set_ylabel('PSRF')
ax.ticklabel_format(useOffset=False) ax.ticklabel_format(useOffset=False)
ax.set_ylim(0.5, self.convergence_plot_upper_lim)
else: else:
axes[0].ticklabel_format(useOffset=False, axis='y') axes[0].ticklabel_format(useOffset=False, axis='y')
cs = chain[:, :, temp].T cs = chain[:, :, temp].T
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment