diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py index 1c1d269fbde14e027ea90b8d505577dc5938027f..cf548a97a3b5e61465c829eb49f8e050ffce98ca 100644 --- a/pyfstat/mcmc_based_searches.py +++ b/pyfstat/mcmc_based_searches.py @@ -216,7 +216,8 @@ class MCMCSearch(BaseSearchClass): def setup_convergence_testing( self, convergence_period=10, convergence_length=10, convergence_burnin_fraction=0.25, convergence_threshold_number=10, - convergence_threshold=1.2, convergence_prod_threshold=2): + convergence_threshold=1.2, convergence_prod_threshold=2, + convergence_plot_upper_lim=2): """ If called, convergence testing is used during the MCMC simulation @@ -243,6 +244,8 @@ class MCMCSearch(BaseSearchClass): recomend a value of 1.2, 1.1 for strict convergence convergence_prod_threshold: float the threshold to test the production values with + convergence_plot_upper_lim: float + the upper limit to use in the diagnostic plot """ if convergence_length > convergence_period: @@ -257,6 +260,7 @@ class MCMCSearch(BaseSearchClass): self.convergence_threshold_number = convergence_threshold_number self.convergence_threshold = convergence_threshold self.convergence_number = 0 + self.convergence_plot_upper_lim = convergence_plot_upper_lim def get_convergence_statistic(self, i, sampler): s = sampler.chain[0, :, i-self.convergence_length+1:i+1, :] @@ -268,15 +272,15 @@ class MCMCSearch(BaseSearchClass): B_over_n = between_std**2 / self.convergence_period Vhat = ((self.convergence_period-1.)/self.convergence_period * W + B_over_n + B_over_n / float(self.nwalkers)) - c = Vhat/W + c = np.sqrt(Vhat/W) self.convergence_diagnostic.append(c) - self.convergence_diagnosticx.append(i - self.convergence_period/2) + self.convergence_diagnosticx.append(i - self.convergence_length/2) return c - def convergence_test(self, i, sampler, nburn): + def burnin_convergence_test(self, i, sampler, nburn): if i < self.convergence_burnin_fraction*nburn: return False - if np.mod(i+1, self.convergence_period) == 0: + if np.mod(i+1, self.convergence_period) != 0: return False c = self.get_convergence_statistic(i, sampler) if np.all(c < self.convergence_threshold): @@ -285,6 +289,12 @@ class MCMCSearch(BaseSearchClass): self.convergence_number = 0 return self.convergence_number > self.convergence_threshold_number + def prod_convergence_test(self, i, sampler, nburn): + testA = i > nburn + self.convergence_length + testB = np.mod(i+1, self.convergence_period) == 0 + if testA and testB: + self.get_convergence_statistic(i, sampler) + def check_production_convergence(self, k): bools = np.any( np.array(self.convergence_diagnostic)[k:, :] @@ -296,26 +306,23 @@ class MCMCSearch(BaseSearchClass): def run_sampler(self, sampler, p0, nprod=0, nburn=0): if hasattr(self, 'convergence_period'): - converged = False logging.info('Running {} burn-in steps with convergence testing' .format(nburn)) iterator = tqdm(sampler.sample(p0, iterations=nburn), total=nburn) for i, output in enumerate(iterator): - if converged: + if self.burnin_convergence_test(i, sampler, nburn): logging.info( 'Converged at {} before max number {} of steps reached' .format(i, nburn)) self.convergence_idx = i break - else: - converged = self.convergence_test(i, sampler, nburn) iterator.close() logging.info('Running {} production steps'.format(nprod)) j = nburn k = len(self.convergence_diagnostic) for result in tqdm(sampler.sample(output[0], iterations=nprod), total=nprod): - self.get_convergence_statistic(j, sampler) + self.prod_convergence_test(j, sampler, nburn) j += 1 self.check_production_convergence(k) return sampler @@ -365,7 +372,7 @@ class MCMCSearch(BaseSearchClass): **kwargs) fig.tight_layout() fig.savefig('{}/{}_init_{}_walkers.png'.format( - self.outdir, self.label, j), dpi=200) + self.outdir, self.label, j), dpi=400) p0 = self.get_new_p0(sampler) p0 = self.apply_corrections_to_p0(p0) @@ -696,9 +703,12 @@ class MCMCSearch(BaseSearchClass): ax = axes[i].twinx() c_x = np.array(self.convergence_diagnosticx) c_y = np.array(self.convergence_diagnostic) - ax.plot(c_x, c_y[:, i], '-b') + break_idx = np.argmin(np.abs(c_x - burnin_idx)) + ax.plot(c_x[:break_idx], c_y[:break_idx, i], '-b') + ax.plot(c_x[break_idx:], c_y[break_idx:, i], '-b') + ax.set_ylabel('PSRF') ax.ticklabel_format(useOffset=False) - ax.set_ylim(1, 5) + ax.set_ylim(1, self.convergence_plot_upper_lim) else: axes[0].ticklabel_format(useOffset=False, axis='y') cs = chain[:, :, temp].T