diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py
index 06cdfc28e5556fe2033186e6422d0b59482ba98c..fefa8e7b841b1fdb295d40b41adca5fd0150716e 100644
--- a/pyfstat/mcmc_based_searches.py
+++ b/pyfstat/mcmc_based_searches.py
@@ -239,61 +239,63 @@ class MCMCSearch(core.BaseSearchClass):
             pass
         return sampler
 
-    def setup_convergence_testing(
-            self, convergence_period=10, convergence_length=10,
-            convergence_burnin_fraction=0.25, convergence_threshold_number=10,
-            convergence_threshold=1.2, convergence_prod_threshold=2,
-            convergence_plot_upper_lim=2, convergence_early_stopping=True):
+    def setup_burnin_convergence_testing(
+            self, n=10, test_type='autocorr', windowed=False, **kwargs):
         """ If called, convergence testing is used during the MCMC simulation
 
-        This uses the Gelmanr-Rubin statistic based on the ratio of between and
-        within walkers variance. The original statistic was developed for
-        multiple (independent) MCMC simulations, in this context we simply use
-        the walkers
-
         Parameters
         ----------
-        convergence_period: int
-            period (in number of steps) at which to test convergence
-        convergence_length: int
-            number of steps to use in testing convergence - this should be
-            large enough to measure the variance, but if it is too long
-            this will result in incorect early convergence tests
-        convergence_burnin_fraction: float [0, 1]
-            the fraction of the burn-in period after which to start testing
-        convergence_threshold_number: int
-            the number of consecutive times where the test passes after which
-            to break the burn-in and go to production
-        convergence_threshold: float
-            the threshold to use in diagnosing convergence. Gelman & Rubin
-            recomend a value of 1.2, 1.1 for strict convergence
-        convergence_prod_threshold: float
-            the threshold to test the production values with
-        convergence_plot_upper_lim: float
-            the upper limit to use in the diagnostic plot
-        convergence_early_stopping: bool
-            if true, stop the burnin early if convergence is reached
+        n: int
+            Number of steps after which to test convergence
+        test_type: str ['autocorr', 'GR']
+            If 'autocorr' use the exponential autocorrelation time (kwargs
+            passed to `test_autocorr_convergence`). If 'GR' use the
+            Gelman-Rubin statistic (kwargs passed to `test_GR_convergence`)
+        windowed: bool
+            If True, only calculate the convergence test in a window of length
+            `n`
 
         """
-
-        if convergence_length > convergence_period:
-            raise ValueError('convergence_length must be < convergence_period')
         logging.info('Setting up convergence testing')
-        self.convergence_length = convergence_length
-        self.convergence_period = convergence_period
-        self.convergence_burnin_fraction = convergence_burnin_fraction
-        self.convergence_prod_threshold = convergence_prod_threshold
+        self.convergence_n = n
+        self.convergence_windowed = windowed
+        self.convergence_test_type = test_type
+        self.convergence_kwargs = kwargs
         self.convergence_diagnostic = []
         self.convergence_diagnosticx = []
-        self.convergence_threshold_number = convergence_threshold_number
-        self.convergence_threshold = convergence_threshold
-        self.convergence_number = 0
-        self.convergence_plot_upper_lim = convergence_plot_upper_lim
-        self.convergence_early_stopping = convergence_early_stopping
-
-    def _get_convergence_statistic(self, i, sampler):
-        s = sampler.chain[0, :, i-self.convergence_length+1:i+1, :]
-        N = float(self.convergence_length)
+        if test_type in ['autocorr']:
+            self._get_convergence_test = self.test_autocorr_convergence
+        elif test_type in ['GR']:
+            self._get_convergence_test = self.test_GR_convergence
+        else:
+            raise ValueError('test_type {} not understood'.format(test_type))
+
+    def test_autocorr_convergence(self, i, sampler, test=True, n_cut=5):
+        try:
+            acors = np.zeros((self.ntemps, self.ndim))
+            for temp in range(self.ntemps):
+                if self.convergence_windowed:
+                    j = i-self.convergence_n
+                else:
+                    j = 0
+                x = np.mean(sampler.chain[temp, :, j:i, :], axis=0)
+                acors[temp, :] = emcee.autocorr.exponential_time(x)
+            c = np.max(acors, axis=0)
+        except emcee.autocorr.AutocorrError:
+            c = np.zeros(self.ndim) + np.nan
+
+        self.convergence_diagnosticx.append(i - self.convergence_n/2.)
+        self.convergence_diagnostic.append(list(c))
+
+        if test:
+            return i > n_cut * np.max(c)
+
+    def test_GR_convergence(self, i, sampler, test=True, R=1.1):
+        if self.convergence_windowed:
+            s = sampler.chain[0, :, i-self.convergence_n+1:i+1, :]
+        else:
+            s = sampler.chain[0, :, :i+1, :]
+        N = float(self.convergence_n)
         M = float(self.nwalkers)
         W = np.mean(np.var(s, axis=1), axis=0)
         per_walker_mean = np.mean(s, axis=1)
@@ -302,58 +304,45 @@ class MCMCSearch(core.BaseSearchClass):
         Vhat = (N-1)/N * W + (M+1)/(M*N) * B
         c = np.sqrt(Vhat/W)
         self.convergence_diagnostic.append(c)
-        self.convergence_diagnosticx.append(i - self.convergence_length/2)
-        return c
+        self.convergence_diagnosticx.append(i - self.convergence_n/2.)
 
-    def _burnin_convergence_test(self, i, sampler, nburn):
-        if i < self.convergence_burnin_fraction*nburn:
-            return False
-        if np.mod(i+1, self.convergence_period) != 0:
+        if test and np.max(c) < R:
+            return True
+        else:
             return False
-        c = self._get_convergence_statistic(i, sampler)
-        if np.all(c < self.convergence_threshold):
-            self.convergence_number += 1
+
+    def _test_convergence(self, i, sampler, **kwargs):
+        if np.mod(i+1, self.convergence_n) == 0:
+            return self._get_convergence_test(i, sampler, **kwargs)
         else:
-            self.convergence_number = 0
-        if self.convergence_early_stopping:
-            return self.convergence_number > self.convergence_threshold_number
-
-    def _prod_convergence_test(self, i, sampler, nburn):
-        testA = i > nburn + self.convergence_length
-        testB = np.mod(i+1, self.convergence_period) == 0
-        if testA and testB:
-            self._get_convergence_statistic(i, sampler)
-
-    def _check_production_convergence(self, k):
-        bools = np.any(
-            np.array(self.convergence_diagnostic)[k:, :]
-            > self.convergence_prod_threshold, axis=1)
-        if np.any(bools):
-            logging.warning(
-                '{} convergence tests in the production run of {} failed'
-                .format(np.sum(bools), len(bools)))
+            return False
+
+    def _run_sampler_with_conv_test(self, sampler, p0, nprod=0, nburn=0):
+        logging.info('Running {} burn-in steps with convergence testing'
+                     .format(nburn))
+        iterator = tqdm(sampler.sample(p0, iterations=nburn), total=nburn)
+        for i, output in enumerate(iterator):
+            if self._test_convergence(i, sampler, test=True,
+                                      **self.convergence_kwargs):
+                logging.info(
+                    'Converged at {} before max number {} of steps reached'
+                    .format(i, nburn))
+                self.convergence_idx = i
+                break
+        iterator.close()
+        logging.info('Running {} production steps'.format(nprod))
+        j = nburn
+        iterator = tqdm(sampler.sample(output[0], iterations=nprod),
+                        total=nprod)
+        for result in iterator:
+            self._test_convergence(j, sampler, test=False,
+                                   **self.convergence_kwargs)
+            j += 1
+        return sampler
 
     def _run_sampler(self, sampler, p0, nprod=0, nburn=0):
-        if hasattr(self, 'convergence_period'):
-            logging.info('Running {} burn-in steps with convergence testing'
-                         .format(nburn))
-            iterator = tqdm(sampler.sample(p0, iterations=nburn), total=nburn)
-            for i, output in enumerate(iterator):
-                if self._burnin_convergence_test(i, sampler, nburn):
-                    logging.info(
-                        'Converged at {} before max number {} of steps reached'
-                        .format(i, nburn))
-                    self.convergence_idx = i
-                    break
-            iterator.close()
-            logging.info('Running {} production steps'.format(nprod))
-            j = nburn
-            k = len(self.convergence_diagnostic)
-            for result in tqdm(sampler.sample(output[0], iterations=nprod),
-                               total=nprod):
-                self._prod_convergence_test(j, sampler, nburn)
-                j += 1
-            self._check_production_convergence(k)
+        if hasattr(self, 'convergence_n'):
+            self._run_sampler_with_conv_test(sampler, p0, nprod, nburn)
         else:
             for result in tqdm(sampler.sample(p0, iterations=nburn+nprod),
                                total=nburn+nprod):
@@ -956,9 +945,11 @@ class MCMCSearch(core.BaseSearchClass):
                                 zorder=-10)
                         ax.plot(c_x[break_idx:], c_y[break_idx:, i], '-C0',
                                 zorder=-10)
-                    ax.set_ylabel('PSRF')
+                    if self.convergence_test_type == 'autocorr':
+                        ax.set_ylabel(r'$\tau_\mathrm{exp}$')
+                    elif self.convergence_test_type == 'GR':
+                        ax.set_ylabel('PSRF')
                     ax.ticklabel_format(useOffset=False)
-                    ax.set_ylim(0.5, self.convergence_plot_upper_lim)
                 else:
                     axes[0].ticklabel_format(useOffset=False, axis='y')
                     cs = chain[:, :, temp].T
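For reviewers, a minimal usage sketch of the new interface follows (the `search` object stands in for an already-configured `MCMCSearch` instance; the keyword values are illustrative, not taken from this diff). The setup call has to happen before `run()`, since `_run_sampler` only dispatches to `_run_sampler_with_conv_test` when the `convergence_n` attribute has been set:

    # Sketch only: `search` is assumed to be a configured pyfstat.MCMCSearch.
    # Test convergence every 10 steps using the exponential autocorrelation
    # time; extra kwargs (here n_cut) are forwarded to
    # test_autocorr_convergence.
    search.setup_burnin_convergence_testing(n=10, test_type='autocorr',
                                            n_cut=5)

    # Alternative: Gelman-Rubin test on a sliding window of 50 steps, with
    # the threshold R forwarded to test_GR_convergence.
    # search.setup_burnin_convergence_testing(n=50, test_type='GR',
    #                                         windowed=True, R=1.1)

    search.run()

With 'autocorr' the burn-in stops once the step index exceeds n_cut times the largest estimated autocorrelation time; with 'GR' it stops once the potential scale reduction factor drops below R. During production the test is called with test=False, so the diagnostic is only recorded to feed the convergence panel of the walker plots.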