Commit afbb9815 authored by Gregory Ashton's avatar Gregory Ashton

Reorganisation and cleanup of convergence testing

- Adds a first attempt at autocorrelation-based convergence testing (using emcee PR 223)
parent 50e20741
@@ -239,61 +239,63 @@ class MCMCSearch(core.BaseSearchClass):
pass
return sampler
def setup_convergence_testing(
self, convergence_period=10, convergence_length=10,
convergence_burnin_fraction=0.25, convergence_threshold_number=10,
convergence_threshold=1.2, convergence_prod_threshold=2,
convergence_plot_upper_lim=2, convergence_early_stopping=True):
def setup_burnin_convergence_testing(
self, n=10, test_type='autocorr', windowed=False, **kwargs):
"""
If called, convergence testing is used during the MCMC simulation
This uses the Gelmanr-Rubin statistic based on the ratio of between and
within walkers variance. The original statistic was developed for
multiple (independent) MCMC simulations, in this context we simply use
the walkers
Parameters
----------
convergence_period: int
period (in number of steps) at which to test convergence
convergence_length: int
number of steps to use in testing convergence - this should be
large enough to measure the variance, but if it is too long
this will result in incorrect early convergence tests
convergence_burnin_fraction: float [0, 1]
the fraction of the burn-in period after which to start testing
convergence_threshold_number: int
the number of consecutive times the test must pass before breaking
out of the burn-in and moving to production
convergence_threshold: float
the threshold to use in diagnosing convergence. Gelman & Rubin
recommend a value of 1.2, or 1.1 for strict convergence
convergence_prod_threshold: float
the threshold to test the production values with
convergence_plot_upper_lim: float
the upper limit to use in the diagnostic plot
convergence_early_stopping: bool
if True, stop the burn-in early if convergence is reached
n: int
Number of steps after which to test convergence
test_type: str ['autocorr', 'GR']
If 'autocorr', use the exponential autocorrelation time (kwargs
passed to `test_autocorr_convergence`). If 'GR', use the Gelman-Rubin
statistic (kwargs passed to `test_GR_convergence`)
windowed: bool
If True, only calculate the convergence test in a window of length
`n`
"""
if convergence_length > convergence_period:
raise ValueError('convergence_length must be <= convergence_period')
logging.info('Setting up convergence testing')
self.convergence_length = convergence_length
self.convergence_period = convergence_period
self.convergence_burnin_fraction = convergence_burnin_fraction
self.convergence_prod_threshold = convergence_prod_threshold
self.convergence_n = n
self.convergence_windowed = windowed
self.convergence_test_type = test_type
self.convergence_kwargs = kwargs
self.convergence_diagnostic = []
self.convergence_diagnosticx = []
self.convergence_threshold_number = convergence_threshold_number
self.convergence_threshold = convergence_threshold
self.convergence_number = 0
self.convergence_plot_upper_lim = convergence_plot_upper_lim
self.convergence_early_stopping = convergence_early_stopping
def _get_convergence_statistic(self, i, sampler):
s = sampler.chain[0, :, i-self.convergence_length+1:i+1, :]
N = float(self.convergence_length)
if test_type in ['autocorr']:
self._get_convergence_test = self.test_autocorr_convergence
elif test_type in ['GR']:
self._get_convergence_test = self.test_GR_convergence
else:
raise ValueError('test_type {} not understood'.format(test_type))
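# Illustrative usage sketch (not part of this commit): assuming `search` is an
# already-constructed MCMCSearch instance, the calls below show how the two
# test types might be selected. The keyword values are hypothetical examples;
# the extra kwargs (n_cut, R) are forwarded to test_autocorr_convergence and
# test_GR_convergence respectively.
#
#   search.setup_burnin_convergence_testing(n=100, test_type='autocorr',
#                                           windowed=False, n_cut=5)
#   search.setup_burnin_convergence_testing(n=100, test_type='GR',
#                                           windowed=True, R=1.1)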
def test_autocorr_convergence(self, i, sampler, test=True, n_cut=5):
try:
acors = np.zeros((self.ntemps, self.ndim))
for temp in range(self.ntemps):
if self.convergence_windowed:
j = i-self.convergence_n
else:
j = 0
x = np.mean(sampler.chain[temp, :, j:i, :], axis=0)
acors[temp, :] = emcee.autocorr.exponential_time(x)
c = np.max(acors, axis=0)
except emcee.autocorr.AutocorrError:
c = np.zeros(self.ndim) + np.nan
self.convergence_diagnosticx.append(i - self.convergence_n/2.)
self.convergence_diagnostic.append(list(c))
if test:
return i > n_cut * np.max(c)
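# Illustrative sketch (not part of this commit): the method above relies on
# emcee.autocorr.exponential_time, which (per the commit message) comes from
# emcee pull request 223 and may not exist in a released emcee. As a rough,
# self-contained stand-in, the standalone snippet below estimates an
# integrated autocorrelation time for a single 1-D chain with plain numpy and
# applies the same "i > n_cut * tau" style of cut. The AR(1) toy chain, the
# summation window and the n_cut value are assumptions for illustration only.
import numpy as np

def acf(x):
    # normalised autocorrelation function of a 1-D series
    x = np.asarray(x, dtype=float) - np.mean(x)
    full = np.correlate(x, x, mode='full')[len(x) - 1:]
    return full / full[0]

def integrated_time(x, window=50):
    # integrated autocorrelation time, summing the ACF up to `window` lags
    return 1 + 2 * np.sum(acf(x)[1:window])

rng = np.random.default_rng(0)
x = np.zeros(2000)
for k in range(1, len(x)):  # AR(1) chain with a known correlation scale
    x[k] = 0.9 * x[k - 1] + rng.normal()
tau = integrated_time(x)
i, n_cut = len(x), 5
print(tau, i > n_cut * tau)  # passes once the chain is several tau long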
def test_GR_convergence(self, i, sampler, test=True, R=1.1):
if self.convergence_windowed:
s = sampler.chain[0, :, i-self.convergence_n+1:i+1, :]
else:
s = sampler.chain[0, :, :i+1, :]
N = float(self.convergence_n)
M = float(self.nwalkers)
W = np.mean(np.var(s, axis=1), axis=0)
per_walker_mean = np.mean(s, axis=1)
@@ -302,58 +304,45 @@ class MCMCSearch(core.BaseSearchClass):
Vhat = (N-1)/N * W + (M+1)/(M*N) * B
c = np.sqrt(Vhat/W)
self.convergence_diagnostic.append(c)
self.convergence_diagnosticx.append(i - self.convergence_length/2)
return c
self.convergence_diagnosticx.append(i - self.convergence_n/2.)
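# Illustrative sketch (not part of this commit): a standalone numpy version of
# the Gelman-Rubin (PSRF) computation used in test_GR_convergence above, run on
# a synthetic chain of shape (nwalkers, nsteps, ndim). Variable names follow
# the method; the between-walker variance B uses the standard definition (the
# line defining it is elided by the diff), and the synthetic data and the 1.1
# threshold are assumptions for illustration only.
import numpy as np

rng = np.random.default_rng(42)
nwalkers, nsteps, ndim = 100, 200, 2
s = rng.normal(size=(nwalkers, nsteps, ndim))  # stands in for sampler.chain[0]

N, M = float(nsteps), float(nwalkers)
W = np.mean(np.var(s, axis=1), axis=0)            # mean within-walker variance
per_walker_mean = np.mean(s, axis=1)              # shape (nwalkers, ndim)
grand_mean = np.mean(per_walker_mean, axis=0)
B = N / (M - 1) * np.sum((per_walker_mean - grand_mean) ** 2, axis=0)
Vhat = (N - 1) / N * W + (M + 1) / (M * N) * B    # pooled variance estimate
PSRF = np.sqrt(Vhat / W)                          # ~1 for a converged chain
print(PSRF, np.all(PSRF < 1.1))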
def _burnin_convergence_test(self, i, sampler, nburn):
if i < self.convergence_burnin_fraction*nburn:
return False
if np.mod(i+1, self.convergence_period) != 0:
if test and np.max(c) < R:
return True
else:
return False
c = self._get_convergence_statistic(i, sampler)
if np.all(c < self.convergence_threshold):
self.convergence_number += 1
def _test_convergence(self, i, sampler, **kwargs):
if np.mod(i+1, self.convergence_n) == 0:
return self._get_convergence_test(i, sampler, **kwargs)
else:
self.convergence_number = 0
if self.convergence_early_stopping:
return self.convergence_number > self.convergence_threshold_number
def _prod_convergence_test(self, i, sampler, nburn):
testA = i > nburn + self.convergence_length
testB = np.mod(i+1, self.convergence_period) == 0
if testA and testB:
self._get_convergence_statistic(i, sampler)
def _check_production_convergence(self, k):
bools = np.any(
np.array(self.convergence_diagnostic)[k:, :]
> self.convergence_prod_threshold, axis=1)
if np.any(bools):
logging.warning(
'{} out of {} convergence tests in the production run failed'
.format(np.sum(bools), len(bools)))
return False
def _run_sampler_with_conv_test(self, sampler, p0, nprod=0, nburn=0):
logging.info('Running {} burn-in steps with convergence testing'
.format(nburn))
iterator = tqdm(sampler.sample(p0, iterations=nburn), total=nburn)
for i, output in enumerate(iterator):
if self._test_convergence(i, sampler, test=True,
**self.convergence_kwargs):
logging.info(
'Converged at step {} before the maximum of {} steps was reached'
.format(i, nburn))
self.convergence_idx = i
break
iterator.close()
logging.info('Running {} production steps'.format(nprod))
j = nburn
iterator = tqdm(sampler.sample(output[0], iterations=nprod),
total=nprod)
for result in iterator:
self._test_convergence(j, sampler, test=False,
**self.convergence_kwargs)
j += 1
return sampler
def _run_sampler(self, sampler, p0, nprod=0, nburn=0):
if hasattr(self, 'convergence_period'):
logging.info('Running {} burn-in steps with convergence testing'
.format(nburn))
iterator = tqdm(sampler.sample(p0, iterations=nburn), total=nburn)
for i, output in enumerate(iterator):
if self._burnin_convergence_test(i, sampler, nburn):
logging.info(
'Converged at step {} before the maximum of {} steps was reached'
.format(i, nburn))
self.convergence_idx = i
break
iterator.close()
logging.info('Running {} production steps'.format(nprod))
j = nburn
k = len(self.convergence_diagnostic)
for result in tqdm(sampler.sample(output[0], iterations=nprod),
total=nprod):
self._prod_convergence_test(j, sampler, nburn)
j += 1
self._check_production_convergence(k)
if hasattr(self, 'convergence_n'):
self._run_sampler_with_conv_test(sampler, p0, nprod, nburn)
else:
for result in tqdm(sampler.sample(p0, iterations=nburn+nprod),
total=nburn+nprod):
@@ -956,9 +945,11 @@ class MCMCSearch(core.BaseSearchClass):
zorder=-10)
ax.plot(c_x[break_idx:], c_y[break_idx:, i], '-C0',
zorder=-10)
ax.set_ylabel('PSRF')
if self.convergence_test_type == 'autocorr':
ax.set_ylabel(r'$\tau_\mathrm{exp}$')
elif self.convergence_test_type == 'GR':
ax.set_ylabel('PSRF')
ax.ticklabel_format(useOffset=False)
ax.set_ylim(0.5, self.convergence_plot_upper_lim)
else:
axes[0].ticklabel_format(useOffset=False, axis='y')
cs = chain[:, :, temp].T
......