diff --git a/examples/glitch_examples/semicoherent_glitch_robust_directed_MCMC_search_on_1_glitch.py b/examples/glitch_examples/semicoherent_glitch_robust_directed_MCMC_search_on_1_glitch.py index 345caa99e7ea0be2a09cb9277653420838d2231a..1de4977f110676ae4928d19a0c6458bd807f92a1 100644 --- a/examples/glitch_examples/semicoherent_glitch_robust_directed_MCMC_search_on_1_glitch.py +++ b/examples/glitch_examples/semicoherent_glitch_robust_directed_MCMC_search_on_1_glitch.py @@ -5,7 +5,7 @@ import gridcorner import time from make_simulated_data import tstart, duration, tref, F0, F1, F2, Alpha, Delta, delta_F0, dtglitch, outdir -plt.style.use('./paper.mplstyle') +#plt.style.use('./paper.mplstyle') label = 'semicoherent_glitch_robust_directed_MCMC_search_on_1_glitch' @@ -60,7 +60,9 @@ mcmc.run() dT = time.time() - t1 fig_and_axes = gridcorner._get_fig_and_axes(4, 2, 0.05) mcmc.plot_corner(label_offset=0.25, truths=[0, 0, 0, 0], - fig_and_axes=fig_and_axes) + fig_and_axes=fig_and_axes, quantiles=(0.16, 0.84)) +#mcmc.plot_chainconsumer(truth=[0, 0, 0, 0], label_offset=0.5) + mcmc.print_summary() print('Prior widths =', F0_width, F1_width) diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py index 9dcc9cef0095777730f2b1ed2115b817f871221e..789da3404aeafb78fa4fce2d3ad8d619780ae1f1 100644 --- a/pyfstat/mcmc_based_searches.py +++ b/pyfstat/mcmc_based_searches.py @@ -787,6 +787,61 @@ class MCMCSearch(core.BaseSearchClass): else: return fig, axes + def plot_chainconsumer( + self, save_fig=True, label_offset=0.25, dpi=300, **kwargs): + """ Generate a corner plot of the posterior using chainconsumer + + Parameters + ---------- + dpi: int + Passed to plt.savefig + **kwargs: + Passed to chainconsumer.plotter.plot + + """ + + if 'truths' in kwargs and len(kwargs['truths']) != self.ndim: + logging.warning('len(Truths) != ndim, Truths will be ignored') + kwargs['truths'] = None + + samples_plt = copy.copy(self.samples) + labels = self._get_labels(newline_units=True) + + samples_plt = self._scale_samples(samples_plt, self.theta_keys) + + import chainconsumer + c = chainconsumer.ChainConsumer() + c.add_chain(samples_plt, parameters=labels) + c.configure(smooth=0, summary=False, sigma2d=True) + fig = c.plotter.plot(**kwargs) + + axes_list = fig.get_axes() + axes = np.array(axes_list).reshape(self.ndim, self.ndim) + plt.draw() + for ax in axes[:, 0]: + ax.yaxis.set_label_coords(-label_offset, 0.5) + for ax in axes[-1, :]: + ax.xaxis.set_label_coords(0.5, -label_offset) + for ax in axes_list: + ax.set_rasterized(True) + ax.set_rasterization_zorder(-10) + + #for tick in ax.xaxis.get_major_ticks(): + # #tick.label.set_fontsize(8) + # tick.label.set_rotation('horizontal') + #for tick in ax.yaxis.get_major_ticks(): + # #tick.label.set_fontsize(8) + # tick.label.set_rotation('vertical') + + plt.tight_layout(h_pad=0.0, w_pad=0.0) + fig.subplots_adjust(hspace=0.05, wspace=0.05) + + if save_fig: + fig.savefig('{}/{}_corner.png'.format( + self.outdir, self.label), dpi=dpi) + else: + return fig + def _add_prior_to_corner(self, axes, samples, add_prior): for i, key in enumerate(self.theta_keys): ax = axes[i][i]