diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py
index d2759dbc1a68a9760bfa3a06107078ce502e1a61..9d9b9ad02ae8748a40c7f4820156abd6df5eb509 100644
--- a/pyfstat/mcmc_based_searches.py
+++ b/pyfstat/mcmc_based_searches.py
@@ -216,8 +216,35 @@ class MCMCSearch(BaseSearchClass):
 
     def setup_convergence_testing(
             self, convergence_period=10, convergence_length=10,
-            convergence_burnin_fraction=0.25, convergence_threshold_number=5,
-            convergence_threshold=1.2):
+            convergence_burnin_fraction=0.25, convergence_threshold_number=10,
+            convergence_threshold=1.2, convergence_prod_threshold=2):
+        """
+        If called, convergence testing is used during the MCMC simulation.
+
+        This uses the Gelman-Rubin statistic, based on the ratio of the
+        between-walker and within-walker variance. The original statistic was
+        developed for multiple (independent) MCMC simulations; in this context
+        we simply use the walkers.
+
+        Parameters
+        ----------
+        convergence_period: int
+            period (in number of steps) at which to test convergence
+        convergence_length: int
+            number of steps to use in testing convergence - this should be
+            large enough to measure the variance, but if it is too long
+            this will result in incorrect early convergence tests
+        convergence_burnin_fraction: float [0, 1]
+            the fraction of the burn-in period after which to start testing
+        convergence_threshold_number: int
+            the number of consecutive times the test must pass before
+            breaking the burn-in and going to production
+        convergence_threshold: float
+            the threshold to use in diagnosing convergence. Gelman & Rubin
+            recommend a value of 1.2, or 1.1 for strict convergence
+        convergence_prod_threshold: float
+            the threshold to test the production values with
+        """
 
         if convergence_length > convergence_period:
             raise ValueError('convergence_length must be < convergence_period')
@@ -225,17 +252,14 @@ class MCMCSearch(BaseSearchClass):
         self.convergence_length = convergence_length
         self.convergence_period = convergence_period
         self.convergence_burnin_fraction = convergence_burnin_fraction
+        self.convergence_prod_threshold = convergence_prod_threshold
         self.convergence_diagnostic = []
        self.convergence_diagnosticx = []
         self.convergence_threshold_number = convergence_threshold_number
         self.convergence_threshold = convergence_threshold
         self.convergence_number = 0
 
-    def convergence_test(self, i, sampler, nburn):
-        if i < self.convergence_burnin_fraction*nburn:
-            return False
-        if np.mod(i+1, self.convergence_period) == 0:
-            return False
+    def get_convergence_statistic(self, i, sampler):
         s = sampler.chain[0, :, i-self.convergence_length+1:i+1, :]
         within_std = np.mean(np.var(s, axis=1), axis=0)
         per_walker_mean = np.mean(s, axis=1)
@@ -248,17 +272,53 @@ class MCMCSearch(BaseSearchClass):
         c = Vhat/W
         self.convergence_diagnostic.append(c)
         self.convergence_diagnosticx.append(i - self.convergence_period/2)
+        return c
+
+    def convergence_test(self, i, sampler, nburn):
+        if i < self.convergence_burnin_fraction*nburn:
+            return False
+        if np.mod(i+1, self.convergence_period) == 0:
+            return False
+        c = self.get_convergence_statistic(i, sampler)
         if np.all(c < self.convergence_threshold):
             self.convergence_number += 1
-
+        else:
+            self.convergence_number = 0
         return self.convergence_number > self.convergence_threshold_number
 
+    def check_production_convergence(self, k):
+        bools = np.any(
+            np.array(self.convergence_diagnostic)[k:, :]
+            > self.convergence_prod_threshold, axis=1)
+        if np.any(bools):
+            logging.warning(
+                '{} convergence tests in the production run of {} failed'
+                .format(np.sum(bools), len(bools)))
+
     def run_sampler(self, sampler, p0, nprod=0, nburn=0):
         if hasattr(self, 'convergence_period'):
-            for i, result in enumerate(tqdm(
-                    sampler.sample(p0, iterations=nburn+nprod),
-                    total=nburn+nprod)):
-                converged = self.convergence_test(i, sampler, nburn)
+            converged = False
+            logging.info('Running {} burn-in steps with convergence testing'
+                         .format(nburn))
+            iterator = tqdm(sampler.sample(p0, iterations=nburn), total=nburn)
+            for i, output in enumerate(iterator):
+                if converged:
+                    logging.info(
+                        'Converged at {} before max number {} of steps reached'
+                        .format(i, nburn))
+                    self.convergence_idx = i
+                    break
+                else:
+                    converged = self.convergence_test(i, sampler, nburn)
+            iterator.close()
+            logging.info('Running {} production steps'.format(nprod))
+            j = nburn
+            k = len(self.convergence_diagnostic)
+            for result in tqdm(sampler.sample(output[0], iterations=nprod),
+                               total=nprod):
+                self.get_convergence_statistic(j, sampler)
+                j += 1
+            self.check_production_convergence(k)
             return sampler
         else:
             for result in tqdm(sampler.sample(p0, iterations=nburn+nprod),
@@ -329,7 +389,7 @@ class MCMCSearch(BaseSearchClass):
         if create_plots:
             fig, axes = self.plot_walkers(sampler,
                                           symbols=self.theta_symbols,
-                                          burnin_idx=nburn, **kwargs)
+                                          nprod=nprod, **kwargs)
             fig.tight_layout()
             fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label),
                         dpi=200)
@@ -572,7 +632,7 @@ class MCMCSearch(BaseSearchClass):
             raise ValueError("dist_type {} unknown".format(dist_type))
 
     def plot_walkers(self, sampler, symbols=None, alpha=0.4, color="k", temp=0,
-                     lw=0.1, burnin_idx=None, add_det_stat_burnin=False,
+                     lw=0.1, nprod=None, add_det_stat_burnin=False,
                      fig=None, axes=None, xoffset=0, plot_det_stat=True,
                      context='classic', subtractions=None, labelpad=0.05):
         """ Plot all the chains from a sampler """
@@ -608,13 +668,18 @@ class MCMCSearch(BaseSearchClass):
                            for i in range(2, ndim+1)]
 
         idxs = np.arange(chain.shape[1])
+        burnin_idx = chain.shape[1] - nprod
+        if hasattr(self, 'convergence_idx'):
+            convergence_idx = self.convergence_idx
+        else:
+            convergence_idx = burnin_idx
         if ndim > 1:
             for i in range(ndim):
                 axes[i].ticklabel_format(useOffset=False, axis='y')
                 cs = chain[:, :, i].T
-                if burnin_idx:
-                    axes[i].plot(xoffset+idxs[:burnin_idx],
-                                 cs[:burnin_idx]-subtractions[i],
+                if burnin_idx > 0:
+                    axes[i].plot(xoffset+idxs[:convergence_idx],
+                                 cs[:convergence_idx]-subtractions[i],
                                  color="r", alpha=alpha,
                                  lw=lw)
                 axes[i].plot(xoffset+idxs[burnin_idx:],
@@ -1665,7 +1730,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
             if create_plots:
                 fig, axes = self.plot_walkers(
                     sampler, symbols=self.theta_symbols, fig=fig, axes=axes,
-                    burnin_idx=nburn, xoffset=nsteps_total, **kwargs)
+                    nprod=nprod, xoffset=nsteps_total, **kwargs)
                 for ax in axes[:self.ndim]:
                     ax.axvline(nsteps_total, color='k', ls='--', lw=0.25)
 
diff --git a/tests.py b/tests.py
index 30cc5e798e337fd1af2eff25ae956412b2951780..7d639413f41cea9abdd45682940862c469fceb9b 100644
--- a/tests.py
+++ b/tests.py
@@ -214,15 +214,15 @@ class TestMCMCSearch(Test):
         Writer.make_data()
         predicted_FS = Writer.predict_fstat()
 
-        theta = {'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-7*F0)},
-                 'F1': {'type': 'norm', 'loc': F1, 'scale': np.abs(1e-3*F1)},
+        theta = {'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-9*F0)},
+                 'F1': {'type': 'norm', 'loc': F1, 'scale': np.abs(1e-9*F1)},
                  'F2': F2, 'Alpha': Alpha, 'Delta': Delta}
         search = pyfstat.MCMCSearch(
             label=self.label, outdir=outdir, theta_prior=theta, tref=tref,
             sftfilepath='{}/*{}*sft'.format(Writer.outdir, Writer.label),
             minStartTime=minStartTime, maxStartTime=maxStartTime,
-            nsteps=[500, 100], nwalkers=100, ntemps=2, log10temperature_min=-1)
+            nsteps=[100, 100], nwalkers=100, ntemps=2, log10temperature_min=-1)
         search.setup_convergence_testing()
         search.run(create_plots=True)
         search.plot_corner(add_prior=True)
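
For context: the new get_convergence_statistic method computes a Gelman-Rubin style ratio of pooled to within-walker variance over the last convergence_length steps, treating each walker as an independent chain. Below is a minimal standalone sketch of the textbook form of that diagnostic; it is not part of the patch, the helper name gelman_rubin is illustrative, and the patch's exact normalisation may differ slightly.

import numpy as np


def gelman_rubin(chain):
    # chain has shape (nwalkers, nsteps, ndim); each walker is treated
    # as an independent chain, as the new docstring describes.
    nwalkers, nsteps, ndim = chain.shape
    # W: mean over walkers of the within-walker variance
    within = np.mean(np.var(chain, axis=1, ddof=1), axis=0)
    # B/n: variance of the per-walker means
    between_over_n = np.var(np.mean(chain, axis=1), axis=0, ddof=1)
    # Vhat: pooled estimate of the posterior variance
    vhat = (nsteps - 1.0) / nsteps * within + between_over_n
    return vhat / within  # approaches 1 as the walkers mix


# Well-mixed independent draws give values close to 1 in every dimension
rng = np.random.default_rng(0)
print(gelman_rubin(rng.normal(size=(100, 50, 3))))

With the patch applied, the updated test in tests.py exercises the new code path by calling search.setup_convergence_testing() before search.run(create_plots=True).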