diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py index 600023764f0c4ad0443fa038732e4a52572c5c1a..9bb1c223428ad7e382a93ab90029d31d00ec3325 100644 --- a/pyfstat/mcmc_based_searches.py +++ b/pyfstat/mcmc_based_searches.py @@ -524,8 +524,9 @@ class MCMCSearch(core.BaseSearchClass): ---------- figsize: tuple (7, 7) Figure size in inches (passed to plt.subplots) - add_prior: bool - If true, plot the prior as a red line + add_prior: bool, str + If true, plot the prior as a red line. If 'full' then for uniform + priors plot the full extent of the prior. nstds: float The number of standard deviations to plot centered on the mean label_offset: float @@ -546,7 +547,7 @@ class MCMCSearch(core.BaseSearchClass): save_fig: bool If true, save the figure, else return the fig, axes - Note: kwargs are passed on to corner.coner + Note: kwargs are passed on to corner.corner """ @@ -599,6 +600,10 @@ class MCMCSearch(core.BaseSearchClass): else: _range = None + hist_kwargs = kwargs.pop('hist_kwargs', dict()) + if 'normed' not in hist_kwargs: + hist_kwargs['normed'] = True + fig_triangle = corner.corner(samples_plt, labels=labels, fig=fig, @@ -610,6 +615,7 @@ class MCMCSearch(core.BaseSearchClass): data_kwargs={'alpha': 0.1, 'ms': 0.5}, range=_range, + hist_kwargs=hist_kwargs, **kwargs) axes_list = fig_triangle.get_axes() @@ -626,7 +632,7 @@ class MCMCSearch(core.BaseSearchClass): fig.subplots_adjust(hspace=0.05, wspace=0.05) if add_prior: - self._add_prior_to_corner(axes, self.samples) + self._add_prior_to_corner(axes, self.samples, add_prior) if save_fig: fig_triangle.savefig('{}/{}_corner.png'.format( @@ -634,19 +640,30 @@ class MCMCSearch(core.BaseSearchClass): else: return fig, axes - def _add_prior_to_corner(self, axes, samples): + def _add_prior_to_corner(self, axes, samples, add_prior): for i, key in enumerate(self.theta_keys): ax = axes[i][i] - xlim = ax.get_xlim() s = samples[:, i] - prior = self._generic_lnprior(**self.theta_prior[key]) - x = np.linspace(s.min(), s.max(), 100) + lnprior = self._generic_lnprior(**self.theta_prior[key]) + if add_prior == 'full' and self.theta_prior[key]['type'] == 'unif': + lower = self.theta_prior[key]['lower'] + upper = self.theta_prior[key]['upper'] + r = upper-lower + xlim = [lower-0.05*r, upper+0.05*r] + x = np.linspace(xlim[0], xlim[1], 1000) + else: + xlim = ax.get_xlim() + x = np.linspace(s.min(), s.max(), 1000) multiplier = self._get_rescale_multiplier_for_key(key) subtractor = self._get_rescale_subtractor_for_key(key) - ax2 = ax.twinx() - ax2.get_yaxis().set_visible(False) - ax2.plot((x-subtractor)*multiplier, [prior(xi) for xi in x], '-r') - ax2.set_xlim(xlim) + ax.plot((x-subtractor)*multiplier, + [np.exp(lnprior(xi)) for xi in x], '-C3', + label='prior') + + for j in range(i, self.ndim): + axes[j][i].set_xlim(xlim[0], xlim[1]) + for k in range(0, i): + axes[i][k].set_ylim(xlim[0], xlim[1]) def plot_prior_posterior(self, normal_stds=2): """ Plot the posterior in the context of the prior """