diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py index acb466b64bac4aef587a6104a1b30ce1adfbc43d..22ed2dcc18f14ac115383b16e5f697df203fc299 100644 --- a/pyfstat/mcmc_based_searches.py +++ b/pyfstat/mcmc_based_searches.py @@ -10,6 +10,7 @@ import numpy as np import matplotlib import matplotlib.pyplot as plt import emcee +import pymc3 import corner import dill as pickle @@ -208,11 +209,69 @@ class MCMCSearch(BaseSearchClass): return p0 - def run_sampler_with_progress_bar(self, sampler, ns, p0): + def OLD_run_sampler_with_progress_bar(self, sampler, ns, p0): for result in tqdm(sampler.sample(p0, iterations=ns), total=ns): pass return sampler + #def run_sampler(self, sampler, ns, p0): + # convergence_period = 200 + # convergence_diagnostic = [] + # convergence_diagnosticx = [] + # for i, result in enumerate(tqdm( + # sampler.sample(p0, iterations=ns), total=ns)): + # if np.mod(i+1, convergence_period) == 0: + # s = sampler.chain[0, :, i-convergence_period+1:i+1, :] + # score_per_parameter = [] + # for j in range(self.ndim): + # scores = [] + # for k in range(self.nwalkers): + # out = pymc3.geweke( + # s[k, :, j].reshape((convergence_period)), + # intervals=2, first=0.4, last=0.4) + # scores.append(out[0][1]) + # score_per_parameter.append(np.median(scores)) + # convergence_diagnostic.append(score_per_parameter) + # convergence_diagnosticx.append(i - convergence_period/2) + # self.convergence_diagnostic = np.array(np.abs(convergence_diagnostic)) + # self.convergence_diagnosticx = convergence_diagnosticx + # return sampler + + #def run_sampler(self, sampler, ns, p0): + # convergence_period = 200 + # convergence_diagnostic = [] + # convergence_diagnosticx = [] + # for i, result in enumerate(tqdm( + # sampler.sample(p0, iterations=ns), total=ns)): + # if np.mod(i+1, convergence_period) == 0: + # s = sampler.chain[0, :, i-convergence_period+1:i+1, :] + # mean_per_chain = np.mean(s, axis=1) + # std_per_chain = np.std(s, axis=1) + # mean = np.mean(mean_per_chain, axis=0) + # B = convergence_period * np.sum((mean_per_chain - mean)**2, axis=0) / (self.nwalkers - 1) + # W = np.sum(std_per_chain**2, axis=0) / self.nwalkers + # print B, W + # convergence_diagnostic.append(W/B) + # convergence_diagnosticx.append(i - convergence_period/2) + # self.convergence_diagnostic = np.array(np.abs(convergence_diagnostic)) + # self.convergence_diagnosticx = convergence_diagnosticx + # return sampler + + def run_sampler(self, sampler, ns, p0): + convergence_period = 200 + convergence_diagnostic = [] + convergence_diagnosticx = [] + for i, result in enumerate(tqdm( + sampler.sample(p0, iterations=ns), total=ns)): + if np.mod(i+1, convergence_period) == 0: + s = sampler.chain[0, :, i-convergence_period+1:i+1, :] + Z = (s - np.mean(s, axis=(0, 1)))/np.std(s, axis=(0, 1)) + convergence_diagnostic.append(np.mean(Z, axis=(0, 1))) + convergence_diagnosticx.append(i - convergence_period/2) + self.convergence_diagnostic = np.array(np.abs(convergence_diagnostic)) + self.convergence_diagnosticx = convergence_diagnosticx + return sampler + def run(self, proposal_scale_factor=2, create_plots=True, **kwargs): self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use() @@ -241,7 +300,7 @@ class MCMCSearch(BaseSearchClass): for j, n in enumerate(self.nsteps[:-2]): logging.info('Running {}/{} initialisation with {} steps'.format( j, ninit_steps, n)) - sampler = self.run_sampler_with_progress_bar(sampler, n, p0) + sampler = self.run_sampler(sampler, n, p0) logging.info("Mean acceptance fraction: {}" .format(np.mean(sampler.acceptance_fraction, axis=1))) if self.ntemps > 1: @@ -267,7 +326,7 @@ class MCMCSearch(BaseSearchClass): nprod = self.nsteps[-1] logging.info('Running final burn and prod with {} steps'.format( nburn+nprod)) - sampler = self.run_sampler_with_progress_bar(sampler, nburn+nprod, p0) + sampler = self.run_sampler(sampler, nburn+nprod, p0) logging.info("Mean acceptance fraction: {}" .format(np.mean(sampler.acceptance_fraction, axis=1))) if self.ntemps > 1: @@ -575,6 +634,10 @@ class MCMCSearch(BaseSearchClass): symbols[i]+'$-$'+symbols[i]+'$_0$', labelpad=labelpad) + if hasattr(self, 'convergence_diagnostic'): + ax = axes[i].twinx() + ax.plot(self.convergence_diagnosticx, + self.convergence_diagnostic[:, i], '-b') else: axes[0].ticklabel_format(useOffset=False, axis='y') cs = chain[:, :, temp].T @@ -623,7 +686,7 @@ class MCMCSearch(BaseSearchClass): axes[-1].set_xlim(minv-0.1*Range, maxv+0.1*Range) xfmt = matplotlib.ticker.ScalarFormatter() - xfmt.set_powerlimits((-4, 4)) + xfmt.set_powerlimits((-4, 4)) axes[-1].xaxis.set_major_formatter(xfmt) axes[-2].set_xlabel(r'$\textrm{Number of steps}$', labelpad=0.2) @@ -1593,8 +1656,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch): logging.info(('Running {}/{} with {} steps and {} nsegs ' '(Tcoh={:1.2f} days)').format( j+1, len(run_setup), (nburn, nprod), nseg, Tcoh)) - sampler = self.run_sampler_with_progress_bar( - sampler, nburn+nprod, p0) + sampler = self.run_sampler(sampler, nburn+nprod, p0) logging.info("Mean acceptance fraction: {}" .format(np.mean(sampler.acceptance_fraction, axis=1))) if self.ntemps > 1: diff --git a/tests.py b/tests.py index b6e8fa09ee7fc0e3f50cf6c4b140690536f23fcb..bcf7ec8ec4f51d264a16d79338f4c81008789e0d 100644 --- a/tests.py +++ b/tests.py @@ -192,7 +192,7 @@ class TestMCMCSearch(Test): label = "Test" def test_fully_coherent(self): - h0 = 1e-24 + h0 = 1e-27 sqrtSX = 1e-22 F0 = 30 F1 = -1e-10 @@ -203,27 +203,26 @@ class TestMCMCSearch(Test): Alpha = 5e-3 Delta = 1.2 tref = minStartTime - dtglitch = None delta_F0 = 0 Writer = pyfstat.Writer(F0=F0, F1=F1, F2=F2, label=self.label, h0=h0, sqrtSX=sqrtSX, outdir=outdir, tstart=minStartTime, Alpha=Alpha, Delta=Delta, tref=tref, - duration=duration, dtglitch=dtglitch, + duration=duration, delta_F0=delta_F0, Band=4) Writer.make_data() predicted_FS = Writer.predict_fstat() - theta = {'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-9*F0)}, - 'F1': {'type': 'norm', 'loc': F1, 'scale': np.abs(1e-9*F1)}, + theta = {'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-8*F0)}, + 'F1': {'type': 'norm', 'loc': F1, 'scale': np.abs(1e-3*F1)}, 'F2': F2, 'Alpha': Alpha, 'Delta': Delta} search = pyfstat.MCMCSearch( label=self.label, outdir=outdir, theta_prior=theta, tref=tref, sftfilepath='{}/*{}*sft'.format(Writer.outdir, Writer.label), minStartTime=minStartTime, maxStartTime=maxStartTime, - nsteps=[100, 100], nwalkers=100, ntemps=1) + nsteps=[500, 100], nwalkers=100, ntemps=2) search.run() search.plot_corner(add_prior=True) _, FS = search.get_max_twoF()