diff --git a/pyfstat.py b/pyfstat.py index 7d7781cd692359cc8b9e58d83161815ea8a3ef7c..9bb2affb1f9387058b6647cbb578ef43cd253d2a 100755 --- a/pyfstat.py +++ b/pyfstat.py @@ -662,7 +662,16 @@ class MCMCSearch(BaseSearchClass): return p0 - def run(self, proposal_scale_factor=None): + def run_sampler_with_progress_bar(self, sampler, ns, p0): + try: + from tqdm import tqdm + for result in tqdm(sampler.sample(p0, iterations=ns), total=ns): + pass + except ImportError: + sampler.run_mcmc(p0, ns) + return sampler + + def run(self, proposal_scale_factor=2): if self.old_data_is_okay_to_use is True: logging.warning('Using saved data from {}'.format( @@ -689,7 +698,7 @@ class MCMCSearch(BaseSearchClass): for j, n in enumerate(self.nsteps[:-2]): logging.info('Running {}/{} initialisation with {} steps'.format( j+1, ninit_steps, n)) - sampler.run_mcmc(p0, n) + sampler = self.run_sampler_with_progress_bar(sampler, n, p0) logging.info("Mean acceptance fraction: {0:.3f}" .format(np.mean(sampler.acceptance_fraction))) if self.ntemps > 1: @@ -704,11 +713,14 @@ class MCMCSearch(BaseSearchClass): self.check_initial_points(p0) sampler.reset() - nburn = self.nsteps[-2] + if len(self.nsteps) > 1: + nburn = self.nsteps[-2] + else: + nburn = 0 nprod = self.nsteps[-1] logging.info('Running final burn and prod with {} steps'.format( nburn+nprod)) - sampler.run_mcmc(p0, nburn+nprod) + sampler = self.run_sampler_with_progress_bar(sampler, nburn+nprod, p0) logging.info("Mean acceptance fraction: {0:.3f}" .format(np.mean(sampler.acceptance_fraction))) if self.ntemps > 1: