diff --git a/pyfstat.py b/pyfstat.py index 91b0e0991e98b9efc44662d865a4c48cbdc4cbad..a38b0a421330efb3824d4b7e976f515dc472c51a 100755 --- a/pyfstat.py +++ b/pyfstat.py @@ -821,6 +821,48 @@ class MCMCSearch(BaseSearchClass): ax2.plot(x, [prior(xi) for xi in x], '-r') ax.set_xlim(xlim) + def plot_prior_posterior(self, normal_stds=2): + """ Plot the posterior in the context of the prior """ + fig, axes = plt.subplots(nrows=self.ndim, figsize=(8, 4*self.ndim)) + N = 1000 + from scipy.stats import gaussian_kde + + for i, (ax, key) in enumerate(zip(axes, self.theta_keys)): + prior_dict = self.theta_prior[key] + prior_func = self.generic_lnprior(**prior_dict) + if prior_dict['type'] == 'unif': + x = np.linspace(prior_dict['lower'], prior_dict['upper'], N) + prior = prior_func(x) + prior[0] = 0 + prior[-1] = 0 + elif prior_dict['type'] == 'norm': + lower = prior_dict['loc'] - normal_stds * prior_dict['scale'] + upper = prior_dict['loc'] + normal_stds * prior_dict['scale'] + x = np.linspace(lower, upper, N) + prior = prior_func(x) + else: + raise ValueError('Not implemented for prior type {}'.format( + prior_dict['type'])) + priorln = ax.plot(x, prior, 'r', label='prior') + ax.set_xlabel(self.theta_symbols[i]) + + s = self.samples[:, i] + while len(s) > 10**4: + # random downsample to avoid slow calculation of kde + s = np.random.choice(s, size=int(len(s)/2.)) + kde = gaussian_kde(s) + ax2 = ax.twinx() + postln = ax2.plot(x, kde.pdf(x), 'k', label='posterior') + ax2.set_yticklabels([]) + ax.set_yticklabels([]) + + lns = priorln + postln + labs = [l.get_label() for l in lns] + axes[0].legend(lns, labs, loc=1, framealpha=0.8) + + fig.savefig('{}/{}_prior_posterior.png'.format( + self.outdir, self.label)) + def generic_lnprior(self, **kwargs): """ Return a lambda function of the pdf