Commit f7aab64d authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Minor polishing to MCMC searches

- Adds chains to saved data
- Add catch for when corner plots error
parent 37041071
...@@ -492,6 +492,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -492,6 +492,7 @@ class MCMCSearch(core.BaseSearchClass):
self.lnprobs = d['lnprobs'] self.lnprobs = d['lnprobs']
self.lnlikes = d['lnlikes'] self.lnlikes = d['lnlikes']
self.all_lnlikelihood = d['all_lnlikelihood'] self.all_lnlikelihood = d['all_lnlikelihood']
self.chain = d['chain']
return return
self._initiate_search_object() self._initiate_search_object()
...@@ -533,21 +534,27 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -533,21 +534,27 @@ class MCMCSearch(core.BaseSearchClass):
logging.info('Running final burn and prod with {} steps'.format( logging.info('Running final burn and prod with {} steps'.format(
nburn+nprod)) nburn+nprod))
sampler = self._run_sampler(sampler, p0, nburn=nburn, nprod=nprod) sampler = self._run_sampler(sampler, p0, nburn=nburn, nprod=nprod)
if create_plots: if create_plots:
fig, axes = self._plot_walkers(sampler, nprod=nprod, **kwargs) try:
fig.tight_layout() fig, axes = self._plot_walkers(sampler, nprod=nprod, **kwargs)
fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label), fig.tight_layout()
) fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label))
except RuntimeError as e:
logging.warning("Failed to save walker plots due to Erro {}"
.format(e))
samples = sampler.chain[0, :, nburn:, :].reshape((-1, self.ndim)) samples = sampler.chain[0, :, nburn:, :].reshape((-1, self.ndim))
lnprobs = sampler.logprobability[0, :, nburn:].reshape((-1)) lnprobs = sampler.logprobability[0, :, nburn:].reshape((-1))
lnlikes = sampler.loglikelihood[0, :, nburn:].reshape((-1)) lnlikes = sampler.loglikelihood[0, :, nburn:].reshape((-1))
all_lnlikelihood = sampler.loglikelihood[:, :, nburn:] all_lnlikelihood = sampler.loglikelihood[:, :, nburn:]
self.samples = samples self.samples = samples
self.chain = sampler.chain
self.lnprobs = lnprobs self.lnprobs = lnprobs
self.lnlikes = lnlikes self.lnlikes = lnlikes
self.all_lnlikelihood = all_lnlikelihood self.all_lnlikelihood = all_lnlikelihood
self._save_data(sampler, samples, lnprobs, lnlikes, all_lnlikelihood) self._save_data(sampler, samples, lnprobs, lnlikes, all_lnlikelihood,
sampler.chain)
return sampler return sampler
def _get_rescale_multiplier_for_key(self, key): def _get_rescale_multiplier_for_key(self, key):
...@@ -1215,11 +1222,13 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -1215,11 +1222,13 @@ class MCMCSearch(core.BaseSearchClass):
maxStartTime=self.maxStartTime) maxStartTime=self.maxStartTime)
return d return d
def _save_data(self, sampler, samples, lnprobs, lnlikes, all_lnlikelihood): def _save_data(self, sampler, samples, lnprobs, lnlikes, all_lnlikelihood,
chain):
d = self._get_data_dictionary_to_save() d = self._get_data_dictionary_to_save()
d['samples'] = samples d['samples'] = samples
d['lnprobs'] = lnprobs d['lnprobs'] = lnprobs
d['lnlikes'] = lnlikes d['lnlikes'] = lnlikes
d['chain'] = chain
d['all_lnlikelihood'] = all_lnlikelihood d['all_lnlikelihood'] = all_lnlikelihood
if os.path.isfile(self.pickle_path): if os.path.isfile(self.pickle_path):
...@@ -1254,6 +1263,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -1254,6 +1263,7 @@ class MCMCSearch(core.BaseSearchClass):
old_d.pop('lnprobs') old_d.pop('lnprobs')
old_d.pop('lnlikes') old_d.pop('lnlikes')
old_d.pop('all_lnlikelihood') old_d.pop('all_lnlikelihood')
old_d.pop('chain')
for key in 'minStartTime', 'maxStartTime': for key in 'minStartTime', 'maxStartTime':
if new_d[key] is None: if new_d[key] is None:
...@@ -1616,7 +1626,7 @@ class MCMCGlitchSearch(MCMCSearch): ...@@ -1616,7 +1626,7 @@ class MCMCGlitchSearch(MCMCSearch):
'multiplier': 1/86400., 'multiplier': 1/86400.,
'subtractor': 'minStartTime', 'subtractor': 'minStartTime',
'unit': 'day', 'unit': 'day',
'label': '$t^{g}_0$ \n [days]'} 'label': '$t^{g}_0$ \n [d]'}
) )
@helper_functions.initializer @helper_functions.initializer
...@@ -2108,6 +2118,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch): ...@@ -2108,6 +2118,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
self.lnprobs = d['lnprobs'] self.lnprobs = d['lnprobs']
self.lnlikes = d['lnlikes'] self.lnlikes = d['lnlikes']
self.all_lnlikelihood = d['all_lnlikelihood'] self.all_lnlikelihood = d['all_lnlikelihood']
self.chain = d['chain']
self.nsegs = run_setup[-1][1] self.nsegs = run_setup[-1][1]
return return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment