diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py index 59d8b837aee6acb5f14f258d9495faada343e6c7..9dcc9cef0095777730f2b1ed2115b817f871221e 100644 --- a/pyfstat/mcmc_based_searches.py +++ b/pyfstat/mcmc_based_searches.py @@ -492,6 +492,7 @@ class MCMCSearch(core.BaseSearchClass): self.lnprobs = d['lnprobs'] self.lnlikes = d['lnlikes'] self.all_lnlikelihood = d['all_lnlikelihood'] + self.chain = d['chain'] return self._initiate_search_object() @@ -533,21 +534,27 @@ class MCMCSearch(core.BaseSearchClass): logging.info('Running final burn and prod with {} steps'.format( nburn+nprod)) sampler = self._run_sampler(sampler, p0, nburn=nburn, nprod=nprod) + if create_plots: - fig, axes = self._plot_walkers(sampler, nprod=nprod, **kwargs) - fig.tight_layout() - fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label), - ) + try: + fig, axes = self._plot_walkers(sampler, nprod=nprod, **kwargs) + fig.tight_layout() + fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label)) + except RuntimeError as e: + logging.warning("Failed to save walker plots due to Erro {}" + .format(e)) samples = sampler.chain[0, :, nburn:, :].reshape((-1, self.ndim)) lnprobs = sampler.logprobability[0, :, nburn:].reshape((-1)) lnlikes = sampler.loglikelihood[0, :, nburn:].reshape((-1)) all_lnlikelihood = sampler.loglikelihood[:, :, nburn:] self.samples = samples + self.chain = sampler.chain self.lnprobs = lnprobs self.lnlikes = lnlikes self.all_lnlikelihood = all_lnlikelihood - self._save_data(sampler, samples, lnprobs, lnlikes, all_lnlikelihood) + self._save_data(sampler, samples, lnprobs, lnlikes, all_lnlikelihood, + sampler.chain) return sampler def _get_rescale_multiplier_for_key(self, key): @@ -1215,11 +1222,13 @@ class MCMCSearch(core.BaseSearchClass): maxStartTime=self.maxStartTime) return d - def _save_data(self, sampler, samples, lnprobs, lnlikes, all_lnlikelihood): + def _save_data(self, sampler, samples, lnprobs, lnlikes, all_lnlikelihood, + chain): d = self._get_data_dictionary_to_save() d['samples'] = samples d['lnprobs'] = lnprobs d['lnlikes'] = lnlikes + d['chain'] = chain d['all_lnlikelihood'] = all_lnlikelihood if os.path.isfile(self.pickle_path): @@ -1254,6 +1263,7 @@ class MCMCSearch(core.BaseSearchClass): old_d.pop('lnprobs') old_d.pop('lnlikes') old_d.pop('all_lnlikelihood') + old_d.pop('chain') for key in 'minStartTime', 'maxStartTime': if new_d[key] is None: @@ -1616,7 +1626,7 @@ class MCMCGlitchSearch(MCMCSearch): 'multiplier': 1/86400., 'subtractor': 'minStartTime', 'unit': 'day', - 'label': '$t^{g}_0$ \n [days]'} + 'label': '$t^{g}_0$ \n [d]'} ) @helper_functions.initializer @@ -2108,6 +2118,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch): self.lnprobs = d['lnprobs'] self.lnlikes = d['lnlikes'] self.all_lnlikelihood = d['all_lnlikelihood'] + self.chain = d['chain'] self.nsegs = run_setup[-1][1] return