Commit 7b99ae40 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Improves plot_walkers

1) Colors the output depending on if it is burn-in or production
2) Adds another panel with the twoF values of all walkers (also colored
by bunr-in or prod)
3) Also adds more logging to writing par files
parent 48b0cd0e
......@@ -728,7 +728,8 @@ class MCMCSearch(BaseSearchClass):
logging.info("Tswap acceptance fraction: {}"
.format(sampler.tswap_acceptance_fraction))
fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols)
fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols,
burnin_idx=nburn)
fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label))
samples = sampler.chain[0, :, nburn:, :].reshape((-1, self.ndim))
......@@ -885,7 +886,7 @@ class MCMCSearch(BaseSearchClass):
raise ValueError("dist_type {} unknown".format(dist_type))
def plot_walkers(self, sampler, symbols=None, alpha=0.4, color="k", temp=0,
start=None, stop=None, draw_vline=None):
burnin_idx=None):
""" Plot all the chains from a sampler """
shape = sampler.chain.shape
......@@ -902,23 +903,37 @@ class MCMCSearch(BaseSearchClass):
chain = sampler.chain[temp, :, :, :]
with plt.style.context(('classic')):
fig, axes = plt.subplots(ndim, 1, sharex=True, figsize=(8, 4*ndim))
fig = plt.figure(figsize=(8, 4*ndim))
ax = fig.add_subplot(ndim+1, 1, 1)
axes = [ax] + [fig.add_subplot(ndim+1, 1, i, sharex=ax)
for i in range(2, ndim+1)]
idxs = np.arange(chain.shape[1])
if ndim > 1:
for i in range(ndim):
axes[i].ticklabel_format(useOffset=False, axis='y')
cs = chain[:, start:stop, i].T
axes[i].plot(cs, color="k", alpha=alpha)
cs = chain[:, :, i].T
if burnin_idx:
axes[i].plot(idxs[:burnin_idx], cs[:burnin_idx],
color="r", alpha=alpha)
axes[i].plot(idxs[burnin_idx:], cs[burnin_idx:], color="k",
alpha=alpha)
if symbols:
axes[i].set_ylabel(symbols[i])
if draw_vline is not None:
axes[i].axvline(draw_vline, lw=2, ls="--")
else:
cs = chain[:, start:stop, 0].T
cs = chain[:, :, temp].T
axes.plot(cs, color='k', alpha=alpha)
axes.ticklabel_format(useOffset=False, axis='y')
axes.append(fig.add_subplot(ndim+1, 1, ndim+1))
lnl = sampler.lnlikelihood[temp, :, :]
if burnin_idx:
axes[-1].hist(lnl[:, :burnin_idx].flatten(), bins=50, histtype='step',
color='r')
axes[-1].hist(lnl[:, burnin_idx:].flatten(), bins=50, histtype='step',
color='k')
axes[-1].set_xlabel(r'$2\mathcal{F}$')
return fig, axes
def apply_corrections_to_p0(self, p0):
......@@ -1148,6 +1163,7 @@ class MCMCSearch(BaseSearchClass):
median_std_d = self.get_median_stds()
max_twoF_d, max_twoF = self.get_max_twoF()
logging.info('Writing par file with max twoF = {}'.format(max_twoF))
filename = '{}/{}.par'.format(self.outdir, self.label)
with open(filename, 'w+') as f:
f.write('MaxtwoF = {}\n'.format(max_twoF))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment