Skip to content
Snippets Groups Projects
Commit 505b0ce3 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Completes adding convergence testing

- restructures the method into several different functions.
- implements stopping if convergence is met a prespecified number of times
- adds documentation
- will plot the convgerence on the RH axis of the walkers plots
parent 35ef58e3
No related branches found
No related tags found
No related merge requests found
......@@ -216,8 +216,35 @@ class MCMCSearch(BaseSearchClass):
def setup_convergence_testing(
self, convergence_period=10, convergence_length=10,
convergence_burnin_fraction=0.25, convergence_threshold_number=5,
convergence_threshold=1.2):
convergence_burnin_fraction=0.25, convergence_threshold_number=10,
convergence_threshold=1.2, convergence_prod_threshold=2):
"""
If called, convergence testing is used during the MCMC simulation
This uses the Gelmanr-Rubin statistic based on the ratio of between and
within walkers variance. The original statistic was developed for
multiple (independent) MCMC simulations, in this context we simply use
the walkers
Parameters
----------
convergence_period: int
period (in number of steps) at which to test convergence
convergence_length: int
number of steps to use in testing convergence - this should be
large enough to measure the variance, but if it is too long
this will result in incorect early convergence tests
convergence_burnin_fraction: float [0, 1]
the fraction of the burn-in period after which to start testing
convergence_threshold_number: int
the number of consecutive times where the test passes after which
to break the burn-in and go to production
convergence_threshold: float
the threshold to use in diagnosing convergence. Gelman & Rubin
recomend a value of 1.2, 1.1 for strict convergence
convergence_prod_threshold: float
the threshold to test the production values with
"""
if convergence_length > convergence_period:
raise ValueError('convergence_length must be < convergence_period')
......@@ -225,17 +252,14 @@ class MCMCSearch(BaseSearchClass):
self.convergence_length = convergence_length
self.convergence_period = convergence_period
self.convergence_burnin_fraction = convergence_burnin_fraction
self.convergence_prod_threshold = convergence_prod_threshold
self.convergence_diagnostic = []
self.convergence_diagnosticx = []
self.convergence_threshold_number = convergence_threshold_number
self.convergence_threshold = convergence_threshold
self.convergence_number = 0
def convergence_test(self, i, sampler, nburn):
if i < self.convergence_burnin_fraction*nburn:
return False
if np.mod(i+1, self.convergence_period) == 0:
return False
def get_convergence_statistic(self, i, sampler):
s = sampler.chain[0, :, i-self.convergence_length+1:i+1, :]
within_std = np.mean(np.var(s, axis=1), axis=0)
per_walker_mean = np.mean(s, axis=1)
......@@ -248,17 +272,53 @@ class MCMCSearch(BaseSearchClass):
c = Vhat/W
self.convergence_diagnostic.append(c)
self.convergence_diagnosticx.append(i - self.convergence_period/2)
return c
def convergence_test(self, i, sampler, nburn):
if i < self.convergence_burnin_fraction*nburn:
return False
if np.mod(i+1, self.convergence_period) == 0:
return False
c = self.get_convergence_statistic(i, sampler)
if np.all(c < self.convergence_threshold):
self.convergence_number += 1
else:
self.convergence_number = 0
return self.convergence_number > self.convergence_threshold_number
def check_production_convergence(self, k):
bools = np.any(
np.array(self.convergence_diagnostic)[k:, :]
> self.convergence_prod_threshold, axis=1)
if np.any(bools):
logging.warning(
'{} convergence tests in the production run of {} failed'
.format(np.sum(bools), len(bools)))
def run_sampler(self, sampler, p0, nprod=0, nburn=0):
if hasattr(self, 'convergence_period'):
for i, result in enumerate(tqdm(
sampler.sample(p0, iterations=nburn+nprod),
total=nburn+nprod)):
converged = False
logging.info('Running {} burn-in steps with convergence testing'
.format(nburn))
iterator = tqdm(sampler.sample(p0, iterations=nburn), total=nburn)
for i, output in enumerate(iterator):
if converged:
logging.info(
'Converged at {} before max number {} of steps reached'
.format(i, nburn))
self.convergence_idx = i
break
else:
converged = self.convergence_test(i, sampler, nburn)
iterator.close()
logging.info('Running {} production steps'.format(nprod))
j = nburn
k = len(self.convergence_diagnostic)
for result in tqdm(sampler.sample(output[0], iterations=nprod),
total=nprod):
self.get_convergence_statistic(j, sampler)
j += 1
self.check_production_convergence(k)
return sampler
else:
for result in tqdm(sampler.sample(p0, iterations=nburn+nprod),
......@@ -329,7 +389,7 @@ class MCMCSearch(BaseSearchClass):
if create_plots:
fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols,
burnin_idx=nburn, **kwargs)
nprod=nprod, **kwargs)
fig.tight_layout()
fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label),
dpi=200)
......@@ -572,7 +632,7 @@ class MCMCSearch(BaseSearchClass):
raise ValueError("dist_type {} unknown".format(dist_type))
def plot_walkers(self, sampler, symbols=None, alpha=0.4, color="k", temp=0,
lw=0.1, burnin_idx=None, add_det_stat_burnin=False,
lw=0.1, nprod=None, add_det_stat_burnin=False,
fig=None, axes=None, xoffset=0, plot_det_stat=True,
context='classic', subtractions=None, labelpad=0.05):
""" Plot all the chains from a sampler """
......@@ -608,13 +668,18 @@ class MCMCSearch(BaseSearchClass):
for i in range(2, ndim+1)]
idxs = np.arange(chain.shape[1])
burnin_idx = chain.shape[1] - nprod
if hasattr(self, 'convergence_idx'):
convergence_idx = self.convergence_idx
else:
convergence_idx = burnin_idx
if ndim > 1:
for i in range(ndim):
axes[i].ticklabel_format(useOffset=False, axis='y')
cs = chain[:, :, i].T
if burnin_idx:
axes[i].plot(xoffset+idxs[:burnin_idx],
cs[:burnin_idx]-subtractions[i],
if burnin_idx > 0:
axes[i].plot(xoffset+idxs[:convergence_idx],
cs[:convergence_idx]-subtractions[i],
color="r", alpha=alpha,
lw=lw)
axes[i].plot(xoffset+idxs[burnin_idx:],
......@@ -1665,7 +1730,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
if create_plots:
fig, axes = self.plot_walkers(
sampler, symbols=self.theta_symbols, fig=fig, axes=axes,
burnin_idx=nburn, xoffset=nsteps_total, **kwargs)
nprod=nprod, xoffset=nsteps_total, **kwargs)
for ax in axes[:self.ndim]:
ax.axvline(nsteps_total, color='k', ls='--', lw=0.25)
......
......@@ -214,15 +214,15 @@ class TestMCMCSearch(Test):
Writer.make_data()
predicted_FS = Writer.predict_fstat()
theta = {'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-7*F0)},
'F1': {'type': 'norm', 'loc': F1, 'scale': np.abs(1e-3*F1)},
theta = {'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-9*F0)},
'F1': {'type': 'norm', 'loc': F1, 'scale': np.abs(1e-9*F1)},
'F2': F2, 'Alpha': Alpha, 'Delta': Delta}
search = pyfstat.MCMCSearch(
label=self.label, outdir=outdir, theta_prior=theta, tref=tref,
sftfilepath='{}/*{}*sft'.format(Writer.outdir, Writer.label),
minStartTime=minStartTime, maxStartTime=maxStartTime,
nsteps=[500, 100], nwalkers=100, ntemps=2, log10temperature_min=-1)
nsteps=[100, 100], nwalkers=100, ntemps=2, log10temperature_min=-1)
search.setup_convergence_testing()
search.run(create_plots=True)
search.plot_corner(add_prior=True)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment