Skip to content
Snippets Groups Projects
Commit f7aab64d authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Minor polishing to MCMC searches

- Adds chains to saved data
- Add catch for when corner plots error
parent 37041071
Branches
Tags
No related merge requests found
......@@ -492,6 +492,7 @@ class MCMCSearch(core.BaseSearchClass):
self.lnprobs = d['lnprobs']
self.lnlikes = d['lnlikes']
self.all_lnlikelihood = d['all_lnlikelihood']
self.chain = d['chain']
return
self._initiate_search_object()
......@@ -533,21 +534,27 @@ class MCMCSearch(core.BaseSearchClass):
logging.info('Running final burn and prod with {} steps'.format(
nburn+nprod))
sampler = self._run_sampler(sampler, p0, nburn=nburn, nprod=nprod)
if create_plots:
try:
fig, axes = self._plot_walkers(sampler, nprod=nprod, **kwargs)
fig.tight_layout()
fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label),
)
fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label))
except RuntimeError as e:
logging.warning("Failed to save walker plots due to Erro {}"
.format(e))
samples = sampler.chain[0, :, nburn:, :].reshape((-1, self.ndim))
lnprobs = sampler.logprobability[0, :, nburn:].reshape((-1))
lnlikes = sampler.loglikelihood[0, :, nburn:].reshape((-1))
all_lnlikelihood = sampler.loglikelihood[:, :, nburn:]
self.samples = samples
self.chain = sampler.chain
self.lnprobs = lnprobs
self.lnlikes = lnlikes
self.all_lnlikelihood = all_lnlikelihood
self._save_data(sampler, samples, lnprobs, lnlikes, all_lnlikelihood)
self._save_data(sampler, samples, lnprobs, lnlikes, all_lnlikelihood,
sampler.chain)
return sampler
def _get_rescale_multiplier_for_key(self, key):
......@@ -1215,11 +1222,13 @@ class MCMCSearch(core.BaseSearchClass):
maxStartTime=self.maxStartTime)
return d
def _save_data(self, sampler, samples, lnprobs, lnlikes, all_lnlikelihood):
def _save_data(self, sampler, samples, lnprobs, lnlikes, all_lnlikelihood,
chain):
d = self._get_data_dictionary_to_save()
d['samples'] = samples
d['lnprobs'] = lnprobs
d['lnlikes'] = lnlikes
d['chain'] = chain
d['all_lnlikelihood'] = all_lnlikelihood
if os.path.isfile(self.pickle_path):
......@@ -1254,6 +1263,7 @@ class MCMCSearch(core.BaseSearchClass):
old_d.pop('lnprobs')
old_d.pop('lnlikes')
old_d.pop('all_lnlikelihood')
old_d.pop('chain')
for key in 'minStartTime', 'maxStartTime':
if new_d[key] is None:
......@@ -1616,7 +1626,7 @@ class MCMCGlitchSearch(MCMCSearch):
'multiplier': 1/86400.,
'subtractor': 'minStartTime',
'unit': 'day',
'label': '$t^{g}_0$ \n [days]'}
'label': '$t^{g}_0$ \n [d]'}
)
@helper_functions.initializer
......@@ -2108,6 +2118,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
self.lnprobs = d['lnprobs']
self.lnlikes = d['lnlikes']
self.all_lnlikelihood = d['all_lnlikelihood']
self.chain = d['chain']
self.nsegs = run_setup[-1][1]
return
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment