Commit 35ef58e3 authored by Gregory Ashton
Browse files

Adds convergence checking in a complete form

Convergence testing based on the Rubin-Gelman statistic is implemented
via `setup_convergence_test`. This does not truncate the search yet, but
simply adds a trace to the walkers plots for testing
parent 62686f20
...@@ -214,63 +214,57 @@ class MCMCSearch(BaseSearchClass): ...@@ -214,63 +214,57 @@ class MCMCSearch(BaseSearchClass):
pass pass
return sampler return sampler
def setup_convergence_testing(
        self, convergence_period=10, convergence_length=10,
        convergence_burnin_fraction=0.25, convergence_threshold_number=5,
        convergence_threshold=1.2):
    """Enable Gelman-Rubin convergence testing during sampling.

    Parameters
    ----------
    convergence_period: int
        Cadence (in iterations) at which the statistic is evaluated.
    convergence_length: int
        Number of trailing samples per walker used in each evaluation;
        must not exceed convergence_period.
    convergence_burnin_fraction: float
        Fraction of the burn-in during which testing is skipped.
    convergence_threshold_number: int
        Number of passing evaluations required before reporting converged.
    convergence_threshold: float
        Upper bound on the statistic for an evaluation to count as a pass.

    Raises
    ------
    ValueError
        If convergence_length exceeds convergence_period.
    """
    if convergence_length > convergence_period:
        raise ValueError('convergence_length must be < convergence_period')
    logging.info('Setting up convergence testing')
    # Record the configuration on the instance in one sweep.
    settings = dict(
        convergence_period=convergence_period,
        convergence_length=convergence_length,
        convergence_burnin_fraction=convergence_burnin_fraction,
        convergence_threshold_number=convergence_threshold_number,
        convergence_threshold=convergence_threshold)
    for key, value in settings.items():
        setattr(self, key, value)
    # Reset the running state accumulated by convergence_test().
    self.convergence_diagnostic = []
    self.convergence_diagnosticx = []
    self.convergence_number = 0
def convergence_test(self, i, sampler, nburn):
    """Evaluate a Gelman-Rubin-style convergence statistic at iteration i.

    Appends the per-parameter statistic (and its x-position) to
    self.convergence_diagnostic / self.convergence_diagnosticx and counts
    consecutive evaluations for which all parameters fall below
    self.convergence_threshold.

    Parameters
    ----------
    i: int
        Current sampler iteration (0-based).
    sampler: emcee-like sampler
        Must expose chain with shape (ntemps, nwalkers, nsteps, ndim);
        only the zero-temperature chain is tested.
    nburn: int
        Number of burn-in steps; testing is skipped during the first
        convergence_burnin_fraction of them.

    Returns
    -------
    bool
        True once more than convergence_threshold_number evaluations
        have passed the threshold.
    """
    if i < self.convergence_burnin_fraction*nburn:
        return False
    # BUGFIX: the test must run ONLY every convergence_period iterations;
    # the previous '== 0: return False' was inverted and ran the (costly)
    # statistic on every other iteration instead.
    if np.mod(i+1, self.convergence_period) != 0:
        return False
    s = sampler.chain[0, :, i-self.convergence_length+1:i+1, :]
    # Within-chain variance W, averaged over walkers.
    within_std = np.mean(np.var(s, axis=1), axis=0)
    per_walker_mean = np.mean(s, axis=1)
    mean = np.mean(per_walker_mean, axis=0)
    # Between-chain spread of the per-walker means.
    between_std = np.sqrt(np.mean((per_walker_mean-mean)**2, axis=0))
    W = within_std
    # The segment s holds convergence_length samples per walker, so the
    # Gelman-Rubin normalization must use that length (the previous code
    # divided by convergence_period, which only agreed at the defaults).
    n = self.convergence_length
    B_over_n = between_std**2 / n
    Vhat = ((n-1.)/n * W
            + B_over_n + B_over_n / float(self.nwalkers))
    c = Vhat/W
    self.convergence_diagnostic.append(c)
    # Place the diagnostic point at the middle of the period for plotting.
    self.convergence_diagnosticx.append(i - self.convergence_period/2)
    if np.all(c < self.convergence_threshold):
        self.convergence_number += 1
    return self.convergence_number > self.convergence_threshold_number
def run_sampler(self, sampler, p0, nprod=0, nburn=0):
    """Advance the sampler through nburn+nprod iterations.

    If convergence testing has been configured (via
    setup_convergence_testing), the convergence statistic is evaluated
    at each iteration; the result is computed but not yet acted upon
    (the run is never truncated early).

    Parameters
    ----------
    sampler: emcee-like sampler
        The sampler to advance.
    p0: array-like
        Initial walker positions passed to sampler.sample.
    nprod, nburn: int
        Production and burn-in step counts; their sum is the total
        number of iterations run.

    Returns
    -------
    sampler
        The same sampler object, after sampling.
    """
    total = nburn + nprod
    iterator = sampler.sample(p0, iterations=total)
    if not hasattr(self, 'convergence_period'):
        # Plain run: just exhaust the sampling iterator.
        for result in tqdm(iterator, total=total):
            pass
        return sampler
    for i, result in enumerate(tqdm(iterator, total=total)):
        # NOTE: the convergence result is recorded for plotting only;
        # it does not stop the run yet.
        converged = self.convergence_test(i, sampler, nburn)
    return sampler
def run(self, proposal_scale_factor=2, create_plots=True, **kwargs): def run(self, proposal_scale_factor=2, create_plots=True, **kwargs):
...@@ -300,7 +294,7 @@ class MCMCSearch(BaseSearchClass): ...@@ -300,7 +294,7 @@ class MCMCSearch(BaseSearchClass):
for j, n in enumerate(self.nsteps[:-2]): for j, n in enumerate(self.nsteps[:-2]):
logging.info('Running {}/{} initialisation with {} steps'.format( logging.info('Running {}/{} initialisation with {} steps'.format(
j, ninit_steps, n)) j, ninit_steps, n))
sampler = self.run_sampler(sampler, n, p0) sampler = self.run_sampler(sampler, p0, nburn=n)
logging.info("Mean acceptance fraction: {}" logging.info("Mean acceptance fraction: {}"
.format(np.mean(sampler.acceptance_fraction, axis=1))) .format(np.mean(sampler.acceptance_fraction, axis=1)))
if self.ntemps > 1: if self.ntemps > 1:
...@@ -326,7 +320,7 @@ class MCMCSearch(BaseSearchClass): ...@@ -326,7 +320,7 @@ class MCMCSearch(BaseSearchClass):
nprod = self.nsteps[-1] nprod = self.nsteps[-1]
logging.info('Running final burn and prod with {} steps'.format( logging.info('Running final burn and prod with {} steps'.format(
nburn+nprod)) nburn+nprod))
sampler = self.run_sampler(sampler, nburn+nprod, p0) sampler = self.run_sampler(sampler, p0, nburn=nburn, nprod=nprod)
logging.info("Mean acceptance fraction: {}" logging.info("Mean acceptance fraction: {}"
.format(np.mean(sampler.acceptance_fraction, axis=1))) .format(np.mean(sampler.acceptance_fraction, axis=1)))
if self.ntemps > 1: if self.ntemps > 1:
...@@ -636,8 +630,11 @@ class MCMCSearch(BaseSearchClass): ...@@ -636,8 +630,11 @@ class MCMCSearch(BaseSearchClass):
if hasattr(self, 'convergence_diagnostic'): if hasattr(self, 'convergence_diagnostic'):
ax = axes[i].twinx() ax = axes[i].twinx()
ax.plot(self.convergence_diagnosticx, c_x = np.array(self.convergence_diagnosticx)
self.convergence_diagnostic[:, i], '-b') c_y = np.array(self.convergence_diagnostic)
ax.plot(c_x, c_y[:, i], '-b')
ax.ticklabel_format(useOffset=False)
ax.set_ylim(1, 5)
else: else:
axes[0].ticklabel_format(useOffset=False, axis='y') axes[0].ticklabel_format(useOffset=False, axis='y')
cs = chain[:, :, temp].T cs = chain[:, :, temp].T
...@@ -1656,7 +1653,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch): ...@@ -1656,7 +1653,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
logging.info(('Running {}/{} with {} steps and {} nsegs ' logging.info(('Running {}/{} with {} steps and {} nsegs '
'(Tcoh={:1.2f} days)').format( '(Tcoh={:1.2f} days)').format(
j+1, len(run_setup), (nburn, nprod), nseg, Tcoh)) j+1, len(run_setup), (nburn, nprod), nseg, Tcoh))
sampler = self.run_sampler(sampler, nburn+nprod, p0) sampler = self.run_sampler(sampler, p0, nburn=nburn, nprod=nprod)
logging.info("Mean acceptance fraction: {}" logging.info("Mean acceptance fraction: {}"
.format(np.mean(sampler.acceptance_fraction, axis=1))) .format(np.mean(sampler.acceptance_fraction, axis=1)))
if self.ntemps > 1: if self.ntemps > 1:
......
...@@ -192,7 +192,7 @@ class TestMCMCSearch(Test): ...@@ -192,7 +192,7 @@ class TestMCMCSearch(Test):
label = "Test" label = "Test"
def test_fully_coherent(self): def test_fully_coherent(self):
h0 = 1e-27 h0 = 1e-24
sqrtSX = 1e-22 sqrtSX = 1e-22
F0 = 30 F0 = 30
F1 = -1e-10 F1 = -1e-10
...@@ -214,7 +214,7 @@ class TestMCMCSearch(Test): ...@@ -214,7 +214,7 @@ class TestMCMCSearch(Test):
Writer.make_data() Writer.make_data()
predicted_FS = Writer.predict_fstat() predicted_FS = Writer.predict_fstat()
theta = {'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-8*F0)}, theta = {'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-7*F0)},
'F1': {'type': 'norm', 'loc': F1, 'scale': np.abs(1e-3*F1)}, 'F1': {'type': 'norm', 'loc': F1, 'scale': np.abs(1e-3*F1)},
'F2': F2, 'Alpha': Alpha, 'Delta': Delta} 'F2': F2, 'Alpha': Alpha, 'Delta': Delta}
...@@ -222,8 +222,9 @@ class TestMCMCSearch(Test): ...@@ -222,8 +222,9 @@ class TestMCMCSearch(Test):
label=self.label, outdir=outdir, theta_prior=theta, tref=tref, label=self.label, outdir=outdir, theta_prior=theta, tref=tref,
sftfilepath='{}/*{}*sft'.format(Writer.outdir, Writer.label), sftfilepath='{}/*{}*sft'.format(Writer.outdir, Writer.label),
minStartTime=minStartTime, maxStartTime=maxStartTime, minStartTime=minStartTime, maxStartTime=maxStartTime,
nsteps=[500, 100], nwalkers=100, ntemps=2) nsteps=[500, 100], nwalkers=100, ntemps=2, log10temperature_min=-1)
search.run() search.setup_convergence_testing()
search.run(create_plots=True)
search.plot_corner(add_prior=True) search.plot_corner(add_prior=True)
_, FS = search.get_max_twoF() _, FS = search.get_max_twoF()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment