Skip to content
Snippets Groups Projects
Commit a6d1a00b authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Adds progress bar

parent d1a8b85e
No related branches found
No related tags found
No related merge requests found
...@@ -662,7 +662,16 @@ class MCMCSearch(BaseSearchClass): ...@@ -662,7 +662,16 @@ class MCMCSearch(BaseSearchClass):
return p0 return p0
def run(self, proposal_scale_factor=None): def run_sampler_with_progress_bar(self, sampler, ns, p0):
try:
from tqdm import tqdm
for result in tqdm(sampler.sample(p0, iterations=ns), total=ns):
pass
except ImportError:
sampler.run_mcmc(p0, ns)
return sampler
def run(self, proposal_scale_factor=2):
if self.old_data_is_okay_to_use is True: if self.old_data_is_okay_to_use is True:
logging.warning('Using saved data from {}'.format( logging.warning('Using saved data from {}'.format(
...@@ -689,7 +698,7 @@ class MCMCSearch(BaseSearchClass): ...@@ -689,7 +698,7 @@ class MCMCSearch(BaseSearchClass):
for j, n in enumerate(self.nsteps[:-2]): for j, n in enumerate(self.nsteps[:-2]):
logging.info('Running {}/{} initialisation with {} steps'.format( logging.info('Running {}/{} initialisation with {} steps'.format(
j+1, ninit_steps, n)) j+1, ninit_steps, n))
sampler.run_mcmc(p0, n) sampler = self.run_sampler_with_progress_bar(sampler, n, p0)
logging.info("Mean acceptance fraction: {0:.3f}" logging.info("Mean acceptance fraction: {0:.3f}"
.format(np.mean(sampler.acceptance_fraction))) .format(np.mean(sampler.acceptance_fraction)))
if self.ntemps > 1: if self.ntemps > 1:
...@@ -704,11 +713,14 @@ class MCMCSearch(BaseSearchClass): ...@@ -704,11 +713,14 @@ class MCMCSearch(BaseSearchClass):
self.check_initial_points(p0) self.check_initial_points(p0)
sampler.reset() sampler.reset()
if len(self.nsteps) > 1:
nburn = self.nsteps[-2] nburn = self.nsteps[-2]
else:
nburn = 0
nprod = self.nsteps[-1] nprod = self.nsteps[-1]
logging.info('Running final burn and prod with {} steps'.format( logging.info('Running final burn and prod with {} steps'.format(
nburn+nprod)) nburn+nprod))
sampler.run_mcmc(p0, nburn+nprod) sampler = self.run_sampler_with_progress_bar(sampler, nburn+nprod, p0)
logging.info("Mean acceptance fraction: {0:.3f}" logging.info("Mean acceptance fraction: {0:.3f}"
.format(np.mean(sampler.acceptance_fraction))) .format(np.mean(sampler.acceptance_fraction)))
if self.ntemps > 1: if self.ntemps > 1:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment