Commit 4fb7886e authored by Gregory Ashton's avatar Gregory Ashton

Add chain consumer and update example

parent fe49dcba
......@@ -5,7 +5,7 @@ import gridcorner
import time
from make_simulated_data import tstart, duration, tref, F0, F1, F2, Alpha, Delta, delta_F0, dtglitch, outdir
plt.style.use('./paper.mplstyle')
#plt.style.use('./paper.mplstyle')
label = 'semicoherent_glitch_robust_directed_MCMC_search_on_1_glitch'
......@@ -60,7 +60,9 @@ mcmc.run()
dT = time.time() - t1
fig_and_axes = gridcorner._get_fig_and_axes(4, 2, 0.05)
mcmc.plot_corner(label_offset=0.25, truths=[0, 0, 0, 0],
fig_and_axes=fig_and_axes)
fig_and_axes=fig_and_axes, quantiles=(0.16, 0.84))
#mcmc.plot_chainconsumer(truth=[0, 0, 0, 0], label_offset=0.5)
mcmc.print_summary()
print('Prior widths =', F0_width, F1_width)
......
......@@ -787,6 +787,61 @@ class MCMCSearch(core.BaseSearchClass):
else:
return fig, axes
def plot_chainconsumer(
self, save_fig=True, label_offset=0.25, dpi=300, **kwargs):
""" Generate a corner plot of the posterior using chainconsumer
Parameters
----------
dpi: int
Passed to plt.savefig
**kwargs:
Passed to chainconsumer.plotter.plot
"""
if 'truths' in kwargs and len(kwargs['truths']) != self.ndim:
logging.warning('len(Truths) != ndim, Truths will be ignored')
kwargs['truths'] = None
samples_plt = copy.copy(self.samples)
labels = self._get_labels(newline_units=True)
samples_plt = self._scale_samples(samples_plt, self.theta_keys)
import chainconsumer
c = chainconsumer.ChainConsumer()
c.add_chain(samples_plt, parameters=labels)
c.configure(smooth=0, summary=False, sigma2d=True)
fig = c.plotter.plot(**kwargs)
axes_list = fig.get_axes()
axes = np.array(axes_list).reshape(self.ndim, self.ndim)
plt.draw()
for ax in axes[:, 0]:
ax.yaxis.set_label_coords(-label_offset, 0.5)
for ax in axes[-1, :]:
ax.xaxis.set_label_coords(0.5, -label_offset)
for ax in axes_list:
ax.set_rasterized(True)
ax.set_rasterization_zorder(-10)
#for tick in ax.xaxis.get_major_ticks():
# #tick.label.set_fontsize(8)
# tick.label.set_rotation('horizontal')
#for tick in ax.yaxis.get_major_ticks():
# #tick.label.set_fontsize(8)
# tick.label.set_rotation('vertical')
plt.tight_layout(h_pad=0.0, w_pad=0.0)
fig.subplots_adjust(hspace=0.05, wspace=0.05)
if save_fig:
fig.savefig('{}/{}_corner.png'.format(
self.outdir, self.label), dpi=dpi)
else:
return fig
def _add_prior_to_corner(self, axes, samples, add_prior):
for i, key in enumerate(self.theta_keys):
ax = axes[i][i]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment