Commit cfc9d8fe authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Improves add_prior functionality

- Uses the actual prior and not the log-prior as previous
- Adds option to plot the full extent (currently only for uniform
  priors)
- ensures the xlims and ylims are reset correctly
parent ed1b32c1
......@@ -524,8 +524,9 @@ class MCMCSearch(core.BaseSearchClass):
----------
figsize: tuple (7, 7)
Figure size in inches (passed to plt.subplots)
add_prior: bool
If true, plot the prior as a red line
add_prior: bool, str
If true, plot the prior as a red line. If 'full' then for uniform
priors plot the full extent of the prior.
nstds: float
The number of standard deviations to plot centered on the mean
label_offset: float
......@@ -546,7 +547,7 @@ class MCMCSearch(core.BaseSearchClass):
save_fig: bool
If true, save the figure, else return the fig, axes
Note: kwargs are passed on to corner.coner
Note: kwargs are passed on to corner.corner
"""
......@@ -599,6 +600,10 @@ class MCMCSearch(core.BaseSearchClass):
else:
_range = None
hist_kwargs = kwargs.pop('hist_kwargs', dict())
if 'normed' not in hist_kwargs:
hist_kwargs['normed'] = True
fig_triangle = corner.corner(samples_plt,
labels=labels,
fig=fig,
......@@ -610,6 +615,7 @@ class MCMCSearch(core.BaseSearchClass):
data_kwargs={'alpha': 0.1,
'ms': 0.5},
range=_range,
hist_kwargs=hist_kwargs,
**kwargs)
axes_list = fig_triangle.get_axes()
......@@ -626,7 +632,7 @@ class MCMCSearch(core.BaseSearchClass):
fig.subplots_adjust(hspace=0.05, wspace=0.05)
if add_prior:
self._add_prior_to_corner(axes, self.samples)
self._add_prior_to_corner(axes, self.samples, add_prior)
if save_fig:
fig_triangle.savefig('{}/{}_corner.png'.format(
......@@ -634,19 +640,30 @@ class MCMCSearch(core.BaseSearchClass):
else:
return fig, axes
def _add_prior_to_corner(self, axes, samples):
def _add_prior_to_corner(self, axes, samples, add_prior):
for i, key in enumerate(self.theta_keys):
ax = axes[i][i]
xlim = ax.get_xlim()
s = samples[:, i]
prior = self._generic_lnprior(**self.theta_prior[key])
x = np.linspace(s.min(), s.max(), 100)
lnprior = self._generic_lnprior(**self.theta_prior[key])
if add_prior == 'full' and self.theta_prior[key]['type'] == 'unif':
lower = self.theta_prior[key]['lower']
upper = self.theta_prior[key]['upper']
r = upper-lower
xlim = [lower-0.05*r, upper+0.05*r]
x = np.linspace(xlim[0], xlim[1], 1000)
else:
xlim = ax.get_xlim()
x = np.linspace(s.min(), s.max(), 1000)
multiplier = self._get_rescale_multiplier_for_key(key)
subtractor = self._get_rescale_subtractor_for_key(key)
ax2 = ax.twinx()
ax2.get_yaxis().set_visible(False)
ax2.plot((x-subtractor)*multiplier, [prior(xi) for xi in x], '-r')
ax2.set_xlim(xlim)
ax.plot((x-subtractor)*multiplier,
[np.exp(lnprior(xi)) for xi in x], '-C3',
label='prior')
for j in range(i, self.ndim):
axes[j][i].set_xlim(xlim[0], xlim[1])
for k in range(0, i):
axes[i][k].set_ylim(xlim[0], xlim[1])
def plot_prior_posterior(self, normal_stds=2):
""" Plot the posterior in the context of the prior """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment