diff --git a/pyfstat.py b/pyfstat.py index c7124ec12e8cd9ac60ce22dda43a57bd24acd6dd..decacc5176e5ce87317f755b89efe7303f62e2e8 100755 --- a/pyfstat.py +++ b/pyfstat.py @@ -728,7 +728,8 @@ class MCMCSearch(BaseSearchClass): logging.info("Tswap acceptance fraction: {}" .format(sampler.tswap_acceptance_fraction)) - fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols) + fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols, + burnin_idx=nburn) fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label)) samples = sampler.chain[0, :, nburn:, :].reshape((-1, self.ndim)) @@ -885,7 +886,7 @@ class MCMCSearch(BaseSearchClass): raise ValueError("dist_type {} unknown".format(dist_type)) def plot_walkers(self, sampler, symbols=None, alpha=0.4, color="k", temp=0, - start=None, stop=None, draw_vline=None): + burnin_idx=None): """ Plot all the chains from a sampler """ shape = sampler.chain.shape @@ -902,23 +903,37 @@ class MCMCSearch(BaseSearchClass): chain = sampler.chain[temp, :, :, :] with plt.style.context(('classic')): - fig, axes = plt.subplots(ndim, 1, sharex=True, figsize=(8, 4*ndim)) + fig = plt.figure(figsize=(8, 4*ndim)) + ax = fig.add_subplot(ndim+1, 1, 1) + axes = [ax] + [fig.add_subplot(ndim+1, 1, i, sharex=ax) + for i in range(2, ndim+1)] + idxs = np.arange(chain.shape[1]) if ndim > 1: for i in range(ndim): axes[i].ticklabel_format(useOffset=False, axis='y') - cs = chain[:, start:stop, i].T - axes[i].plot(cs, color="k", alpha=alpha) + cs = chain[:, :, i].T + if burnin_idx: + axes[i].plot(idxs[:burnin_idx], cs[:burnin_idx], + color="r", alpha=alpha) + axes[i].plot(idxs[burnin_idx:], cs[burnin_idx:], color="k", + alpha=alpha) if symbols: axes[i].set_ylabel(symbols[i]) - if draw_vline is not None: - axes[i].axvline(draw_vline, lw=2, ls="--") - else: - cs = chain[:, start:stop, 0].T + cs = chain[:, :, temp].T axes.plot(cs, color='k', alpha=alpha) axes.ticklabel_format(useOffset=False, axis='y') + axes.append(fig.add_subplot(ndim+1, 1, ndim+1)) + lnl = sampler.lnlikelihood[temp, :, :] + if burnin_idx: + axes[-1].hist(lnl[:, :burnin_idx].flatten(), bins=50, histtype='step', + color='r') + axes[-1].hist(lnl[:, burnin_idx:].flatten(), bins=50, histtype='step', + color='k') + axes[-1].set_xlabel(r'$2\mathcal{F}$') + return fig, axes def apply_corrections_to_p0(self, p0): @@ -1148,6 +1163,7 @@ class MCMCSearch(BaseSearchClass): median_std_d = self.get_median_stds() max_twoF_d, max_twoF = self.get_max_twoF() + logging.info('Writing par file with max twoF = {}'.format(max_twoF)) filename = '{}/{}.par'.format(self.outdir, self.label) with open(filename, 'w+') as f: f.write('MaxtwoF = {}\n'.format(max_twoF))