diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py index 22ed2dcc18f14ac115383b16e5f697df203fc299..d2759dbc1a68a9760bfa3a06107078ce502e1a61 100644 --- a/pyfstat/mcmc_based_searches.py +++ b/pyfstat/mcmc_based_searches.py @@ -214,63 +214,57 @@ class MCMCSearch(BaseSearchClass): pass return sampler - #def run_sampler(self, sampler, ns, p0): - # convergence_period = 200 - # convergence_diagnostic = [] - # convergence_diagnosticx = [] - # for i, result in enumerate(tqdm( - # sampler.sample(p0, iterations=ns), total=ns)): - # if np.mod(i+1, convergence_period) == 0: - # s = sampler.chain[0, :, i-convergence_period+1:i+1, :] - # score_per_parameter = [] - # for j in range(self.ndim): - # scores = [] - # for k in range(self.nwalkers): - # out = pymc3.geweke( - # s[k, :, j].reshape((convergence_period)), - # intervals=2, first=0.4, last=0.4) - # scores.append(out[0][1]) - # score_per_parameter.append(np.median(scores)) - # convergence_diagnostic.append(score_per_parameter) - # convergence_diagnosticx.append(i - convergence_period/2) - # self.convergence_diagnostic = np.array(np.abs(convergence_diagnostic)) - # self.convergence_diagnosticx = convergence_diagnosticx - # return sampler - - #def run_sampler(self, sampler, ns, p0): - # convergence_period = 200 - # convergence_diagnostic = [] - # convergence_diagnosticx = [] - # for i, result in enumerate(tqdm( - # sampler.sample(p0, iterations=ns), total=ns)): - # if np.mod(i+1, convergence_period) == 0: - # s = sampler.chain[0, :, i-convergence_period+1:i+1, :] - # mean_per_chain = np.mean(s, axis=1) - # std_per_chain = np.std(s, axis=1) - # mean = np.mean(mean_per_chain, axis=0) - # B = convergence_period * np.sum((mean_per_chain - mean)**2, axis=0) / (self.nwalkers - 1) - # W = np.sum(std_per_chain**2, axis=0) / self.nwalkers - # print B, W - # convergence_diagnostic.append(W/B) - # convergence_diagnosticx.append(i - convergence_period/2) - # self.convergence_diagnostic = np.array(np.abs(convergence_diagnostic)) - # self.convergence_diagnosticx = convergence_diagnosticx - # return sampler - - def run_sampler(self, sampler, ns, p0): - convergence_period = 200 - convergence_diagnostic = [] - convergence_diagnosticx = [] - for i, result in enumerate(tqdm( - sampler.sample(p0, iterations=ns), total=ns)): - if np.mod(i+1, convergence_period) == 0: - s = sampler.chain[0, :, i-convergence_period+1:i+1, :] - Z = (s - np.mean(s, axis=(0, 1)))/np.std(s, axis=(0, 1)) - convergence_diagnostic.append(np.mean(Z, axis=(0, 1))) - convergence_diagnosticx.append(i - convergence_period/2) - self.convergence_diagnostic = np.array(np.abs(convergence_diagnostic)) - self.convergence_diagnosticx = convergence_diagnosticx - return sampler + def setup_convergence_testing( + self, convergence_period=10, convergence_length=10, + convergence_burnin_fraction=0.25, convergence_threshold_number=5, + convergence_threshold=1.2): + + if convergence_length > convergence_period: + raise ValueError('convergence_length must be < convergence_period') + logging.info('Setting up convergence testing') + self.convergence_length = convergence_length + self.convergence_period = convergence_period + self.convergence_burnin_fraction = convergence_burnin_fraction + self.convergence_diagnostic = [] + self.convergence_diagnosticx = [] + self.convergence_threshold_number = convergence_threshold_number + self.convergence_threshold = convergence_threshold + self.convergence_number = 0 + + def convergence_test(self, i, sampler, nburn): + if i < self.convergence_burnin_fraction*nburn: + return False + if np.mod(i+1, self.convergence_period) == 0: + return False + s = sampler.chain[0, :, i-self.convergence_length+1:i+1, :] + within_std = np.mean(np.var(s, axis=1), axis=0) + per_walker_mean = np.mean(s, axis=1) + mean = np.mean(per_walker_mean, axis=0) + between_std = np.sqrt(np.mean((per_walker_mean-mean)**2, axis=0)) + W = within_std + B_over_n = between_std**2 / self.convergence_period + Vhat = ((self.convergence_period-1.)/self.convergence_period * W + + B_over_n + B_over_n / float(self.nwalkers)) + c = Vhat/W + self.convergence_diagnostic.append(c) + self.convergence_diagnosticx.append(i - self.convergence_period/2) + if np.all(c < self.convergence_threshold): + self.convergence_number += 1 + + return self.convergence_number > self.convergence_threshold_number + + def run_sampler(self, sampler, p0, nprod=0, nburn=0): + if hasattr(self, 'convergence_period'): + for i, result in enumerate(tqdm( + sampler.sample(p0, iterations=nburn+nprod), + total=nburn+nprod)): + converged = self.convergence_test(i, sampler, nburn) + return sampler + else: + for result in tqdm(sampler.sample(p0, iterations=nburn+nprod), + total=nburn+nprod): + pass + return sampler def run(self, proposal_scale_factor=2, create_plots=True, **kwargs): @@ -300,7 +294,7 @@ class MCMCSearch(BaseSearchClass): for j, n in enumerate(self.nsteps[:-2]): logging.info('Running {}/{} initialisation with {} steps'.format( j, ninit_steps, n)) - sampler = self.run_sampler(sampler, n, p0) + sampler = self.run_sampler(sampler, p0, nburn=n) logging.info("Mean acceptance fraction: {}" .format(np.mean(sampler.acceptance_fraction, axis=1))) if self.ntemps > 1: @@ -326,7 +320,7 @@ class MCMCSearch(BaseSearchClass): nprod = self.nsteps[-1] logging.info('Running final burn and prod with {} steps'.format( nburn+nprod)) - sampler = self.run_sampler(sampler, nburn+nprod, p0) + sampler = self.run_sampler(sampler, p0, nburn=nburn, nprod=nprod) logging.info("Mean acceptance fraction: {}" .format(np.mean(sampler.acceptance_fraction, axis=1))) if self.ntemps > 1: @@ -636,8 +630,11 @@ class MCMCSearch(BaseSearchClass): if hasattr(self, 'convergence_diagnostic'): ax = axes[i].twinx() - ax.plot(self.convergence_diagnosticx, - self.convergence_diagnostic[:, i], '-b') + c_x = np.array(self.convergence_diagnosticx) + c_y = np.array(self.convergence_diagnostic) + ax.plot(c_x, c_y[:, i], '-b') + ax.ticklabel_format(useOffset=False) + ax.set_ylim(1, 5) else: axes[0].ticklabel_format(useOffset=False, axis='y') cs = chain[:, :, temp].T @@ -1656,7 +1653,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch): logging.info(('Running {}/{} with {} steps and {} nsegs ' '(Tcoh={:1.2f} days)').format( j+1, len(run_setup), (nburn, nprod), nseg, Tcoh)) - sampler = self.run_sampler(sampler, nburn+nprod, p0) + sampler = self.run_sampler(sampler, p0, nburn=nburn, nprod=nprod) logging.info("Mean acceptance fraction: {}" .format(np.mean(sampler.acceptance_fraction, axis=1))) if self.ntemps > 1: diff --git a/tests.py b/tests.py index bcf7ec8ec4f51d264a16d79338f4c81008789e0d..30cc5e798e337fd1af2eff25ae956412b2951780 100644 --- a/tests.py +++ b/tests.py @@ -192,7 +192,7 @@ class TestMCMCSearch(Test): label = "Test" def test_fully_coherent(self): - h0 = 1e-27 + h0 = 1e-24 sqrtSX = 1e-22 F0 = 30 F1 = -1e-10 @@ -214,7 +214,7 @@ class TestMCMCSearch(Test): Writer.make_data() predicted_FS = Writer.predict_fstat() - theta = {'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-8*F0)}, + theta = {'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-7*F0)}, 'F1': {'type': 'norm', 'loc': F1, 'scale': np.abs(1e-3*F1)}, 'F2': F2, 'Alpha': Alpha, 'Delta': Delta} @@ -222,8 +222,9 @@ class TestMCMCSearch(Test): label=self.label, outdir=outdir, theta_prior=theta, tref=tref, sftfilepath='{}/*{}*sft'.format(Writer.outdir, Writer.label), minStartTime=minStartTime, maxStartTime=maxStartTime, - nsteps=[500, 100], nwalkers=100, ntemps=2) - search.run() + nsteps=[500, 100], nwalkers=100, ntemps=2, log10temperature_min=-1) + search.setup_convergence_testing() + search.run(create_plots=True) search.plot_corner(add_prior=True) _, FS = search.get_max_twoF()