Commit 62686f20 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Various rough changes to add convergance checking

parent 4966dc2c
......@@ -10,6 +10,7 @@ import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import emcee
import pymc3
import corner
import dill as pickle
......@@ -208,11 +209,69 @@ class MCMCSearch(BaseSearchClass):
return p0
def run_sampler_with_progress_bar(self, sampler, ns, p0):
def OLD_run_sampler_with_progress_bar(self, sampler, ns, p0):
for result in tqdm(sampler.sample(p0, iterations=ns), total=ns):
pass
return sampler
#def run_sampler(self, sampler, ns, p0):
# convergence_period = 200
# convergence_diagnostic = []
# convergence_diagnosticx = []
# for i, result in enumerate(tqdm(
# sampler.sample(p0, iterations=ns), total=ns)):
# if np.mod(i+1, convergence_period) == 0:
# s = sampler.chain[0, :, i-convergence_period+1:i+1, :]
# score_per_parameter = []
# for j in range(self.ndim):
# scores = []
# for k in range(self.nwalkers):
# out = pymc3.geweke(
# s[k, :, j].reshape((convergence_period)),
# intervals=2, first=0.4, last=0.4)
# scores.append(out[0][1])
# score_per_parameter.append(np.median(scores))
# convergence_diagnostic.append(score_per_parameter)
# convergence_diagnosticx.append(i - convergence_period/2)
# self.convergence_diagnostic = np.array(np.abs(convergence_diagnostic))
# self.convergence_diagnosticx = convergence_diagnosticx
# return sampler
#def run_sampler(self, sampler, ns, p0):
# convergence_period = 200
# convergence_diagnostic = []
# convergence_diagnosticx = []
# for i, result in enumerate(tqdm(
# sampler.sample(p0, iterations=ns), total=ns)):
# if np.mod(i+1, convergence_period) == 0:
# s = sampler.chain[0, :, i-convergence_period+1:i+1, :]
# mean_per_chain = np.mean(s, axis=1)
# std_per_chain = np.std(s, axis=1)
# mean = np.mean(mean_per_chain, axis=0)
# B = convergence_period * np.sum((mean_per_chain - mean)**2, axis=0) / (self.nwalkers - 1)
# W = np.sum(std_per_chain**2, axis=0) / self.nwalkers
# print B, W
# convergence_diagnostic.append(W/B)
# convergence_diagnosticx.append(i - convergence_period/2)
# self.convergence_diagnostic = np.array(np.abs(convergence_diagnostic))
# self.convergence_diagnosticx = convergence_diagnosticx
# return sampler
def run_sampler(self, sampler, ns, p0):
convergence_period = 200
convergence_diagnostic = []
convergence_diagnosticx = []
for i, result in enumerate(tqdm(
sampler.sample(p0, iterations=ns), total=ns)):
if np.mod(i+1, convergence_period) == 0:
s = sampler.chain[0, :, i-convergence_period+1:i+1, :]
Z = (s - np.mean(s, axis=(0, 1)))/np.std(s, axis=(0, 1))
convergence_diagnostic.append(np.mean(Z, axis=(0, 1)))
convergence_diagnosticx.append(i - convergence_period/2)
self.convergence_diagnostic = np.array(np.abs(convergence_diagnostic))
self.convergence_diagnosticx = convergence_diagnosticx
return sampler
def run(self, proposal_scale_factor=2, create_plots=True, **kwargs):
self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
......@@ -241,7 +300,7 @@ class MCMCSearch(BaseSearchClass):
for j, n in enumerate(self.nsteps[:-2]):
logging.info('Running {}/{} initialisation with {} steps'.format(
j, ninit_steps, n))
sampler = self.run_sampler_with_progress_bar(sampler, n, p0)
sampler = self.run_sampler(sampler, n, p0)
logging.info("Mean acceptance fraction: {}"
.format(np.mean(sampler.acceptance_fraction, axis=1)))
if self.ntemps > 1:
......@@ -267,7 +326,7 @@ class MCMCSearch(BaseSearchClass):
nprod = self.nsteps[-1]
logging.info('Running final burn and prod with {} steps'.format(
nburn+nprod))
sampler = self.run_sampler_with_progress_bar(sampler, nburn+nprod, p0)
sampler = self.run_sampler(sampler, nburn+nprod, p0)
logging.info("Mean acceptance fraction: {}"
.format(np.mean(sampler.acceptance_fraction, axis=1)))
if self.ntemps > 1:
......@@ -575,6 +634,10 @@ class MCMCSearch(BaseSearchClass):
symbols[i]+'$-$'+symbols[i]+'$_0$',
labelpad=labelpad)
if hasattr(self, 'convergence_diagnostic'):
ax = axes[i].twinx()
ax.plot(self.convergence_diagnosticx,
self.convergence_diagnostic[:, i], '-b')
else:
axes[0].ticklabel_format(useOffset=False, axis='y')
cs = chain[:, :, temp].T
......@@ -623,7 +686,7 @@ class MCMCSearch(BaseSearchClass):
axes[-1].set_xlim(minv-0.1*Range, maxv+0.1*Range)
xfmt = matplotlib.ticker.ScalarFormatter()
xfmt.set_powerlimits((-4, 4))
xfmt.set_powerlimits((-4, 4))
axes[-1].xaxis.set_major_formatter(xfmt)
axes[-2].set_xlabel(r'$\textrm{Number of steps}$', labelpad=0.2)
......@@ -1593,8 +1656,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
logging.info(('Running {}/{} with {} steps and {} nsegs '
'(Tcoh={:1.2f} days)').format(
j+1, len(run_setup), (nburn, nprod), nseg, Tcoh))
sampler = self.run_sampler_with_progress_bar(
sampler, nburn+nprod, p0)
sampler = self.run_sampler(sampler, nburn+nprod, p0)
logging.info("Mean acceptance fraction: {}"
.format(np.mean(sampler.acceptance_fraction, axis=1)))
if self.ntemps > 1:
......
......@@ -192,7 +192,7 @@ class TestMCMCSearch(Test):
label = "Test"
def test_fully_coherent(self):
h0 = 1e-24
h0 = 1e-27
sqrtSX = 1e-22
F0 = 30
F1 = -1e-10
......@@ -203,27 +203,26 @@ class TestMCMCSearch(Test):
Alpha = 5e-3
Delta = 1.2
tref = minStartTime
dtglitch = None
delta_F0 = 0
Writer = pyfstat.Writer(F0=F0, F1=F1, F2=F2, label=self.label,
h0=h0, sqrtSX=sqrtSX,
outdir=outdir, tstart=minStartTime,
Alpha=Alpha, Delta=Delta, tref=tref,
duration=duration, dtglitch=dtglitch,
duration=duration,
delta_F0=delta_F0, Band=4)
Writer.make_data()
predicted_FS = Writer.predict_fstat()
theta = {'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-9*F0)},
'F1': {'type': 'norm', 'loc': F1, 'scale': np.abs(1e-9*F1)},
theta = {'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-8*F0)},
'F1': {'type': 'norm', 'loc': F1, 'scale': np.abs(1e-3*F1)},
'F2': F2, 'Alpha': Alpha, 'Delta': Delta}
search = pyfstat.MCMCSearch(
label=self.label, outdir=outdir, theta_prior=theta, tref=tref,
sftfilepath='{}/*{}*sft'.format(Writer.outdir, Writer.label),
minStartTime=minStartTime, maxStartTime=maxStartTime,
nsteps=[100, 100], nwalkers=100, ntemps=1)
nsteps=[500, 100], nwalkers=100, ntemps=2)
search.run()
search.plot_corner(add_prior=True)
_, FS = search.get_max_twoF()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment