Commit cfc9d8fe authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Improves add_prior functionality

- Uses the actual prior and not the log-prior as previous
- Adds option to plot the full extent (currently only for uniform
  priors)
- ensures the xlims and ylims are reset correctly
parent ed1b32c1
...@@ -524,8 +524,9 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -524,8 +524,9 @@ class MCMCSearch(core.BaseSearchClass):
---------- ----------
figsize: tuple (7, 7) figsize: tuple (7, 7)
Figure size in inches (passed to plt.subplots) Figure size in inches (passed to plt.subplots)
add_prior: bool add_prior: bool, str
If true, plot the prior as a red line If true, plot the prior as a red line. If 'full' then for uniform
priors plot the full extent of the prior.
nstds: float nstds: float
The number of standard deviations to plot centered on the mean The number of standard deviations to plot centered on the mean
label_offset: float label_offset: float
...@@ -546,7 +547,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -546,7 +547,7 @@ class MCMCSearch(core.BaseSearchClass):
save_fig: bool save_fig: bool
If true, save the figure, else return the fig, axes If true, save the figure, else return the fig, axes
Note: kwargs are passed on to corner.coner Note: kwargs are passed on to corner.corner
""" """
...@@ -599,6 +600,10 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -599,6 +600,10 @@ class MCMCSearch(core.BaseSearchClass):
else: else:
_range = None _range = None
hist_kwargs = kwargs.pop('hist_kwargs', dict())
if 'normed' not in hist_kwargs:
hist_kwargs['normed'] = True
fig_triangle = corner.corner(samples_plt, fig_triangle = corner.corner(samples_plt,
labels=labels, labels=labels,
fig=fig, fig=fig,
...@@ -610,6 +615,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -610,6 +615,7 @@ class MCMCSearch(core.BaseSearchClass):
data_kwargs={'alpha': 0.1, data_kwargs={'alpha': 0.1,
'ms': 0.5}, 'ms': 0.5},
range=_range, range=_range,
hist_kwargs=hist_kwargs,
**kwargs) **kwargs)
axes_list = fig_triangle.get_axes() axes_list = fig_triangle.get_axes()
...@@ -626,7 +632,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -626,7 +632,7 @@ class MCMCSearch(core.BaseSearchClass):
fig.subplots_adjust(hspace=0.05, wspace=0.05) fig.subplots_adjust(hspace=0.05, wspace=0.05)
if add_prior: if add_prior:
self._add_prior_to_corner(axes, self.samples) self._add_prior_to_corner(axes, self.samples, add_prior)
if save_fig: if save_fig:
fig_triangle.savefig('{}/{}_corner.png'.format( fig_triangle.savefig('{}/{}_corner.png'.format(
...@@ -634,19 +640,30 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -634,19 +640,30 @@ class MCMCSearch(core.BaseSearchClass):
else: else:
return fig, axes return fig, axes
def _add_prior_to_corner(self, axes, samples): def _add_prior_to_corner(self, axes, samples, add_prior):
for i, key in enumerate(self.theta_keys): for i, key in enumerate(self.theta_keys):
ax = axes[i][i] ax = axes[i][i]
xlim = ax.get_xlim()
s = samples[:, i] s = samples[:, i]
prior = self._generic_lnprior(**self.theta_prior[key]) lnprior = self._generic_lnprior(**self.theta_prior[key])
x = np.linspace(s.min(), s.max(), 100) if add_prior == 'full' and self.theta_prior[key]['type'] == 'unif':
lower = self.theta_prior[key]['lower']
upper = self.theta_prior[key]['upper']
r = upper-lower
xlim = [lower-0.05*r, upper+0.05*r]
x = np.linspace(xlim[0], xlim[1], 1000)
else:
xlim = ax.get_xlim()
x = np.linspace(s.min(), s.max(), 1000)
multiplier = self._get_rescale_multiplier_for_key(key) multiplier = self._get_rescale_multiplier_for_key(key)
subtractor = self._get_rescale_subtractor_for_key(key) subtractor = self._get_rescale_subtractor_for_key(key)
ax2 = ax.twinx() ax.plot((x-subtractor)*multiplier,
ax2.get_yaxis().set_visible(False) [np.exp(lnprior(xi)) for xi in x], '-C3',
ax2.plot((x-subtractor)*multiplier, [prior(xi) for xi in x], '-r') label='prior')
ax2.set_xlim(xlim)
for j in range(i, self.ndim):
axes[j][i].set_xlim(xlim[0], xlim[1])
for k in range(0, i):
axes[i][k].set_ylim(xlim[0], xlim[1])
def plot_prior_posterior(self, normal_stds=2): def plot_prior_posterior(self, normal_stds=2):
""" Plot the posterior in the context of the prior """ """ Plot the posterior in the context of the prior """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment