diff --git a/pyfstat.py b/pyfstat.py
index 91b0e0991e98b9efc44662d865a4c48cbdc4cbad..a38b0a421330efb3824d4b7e976f515dc472c51a 100755
--- a/pyfstat.py
+++ b/pyfstat.py
@@ -821,6 +821,48 @@ class MCMCSearch(BaseSearchClass):
             ax2.plot(x, [prior(xi) for xi in x], '-r')
             ax.set_xlim(xlim)
 
+    def plot_prior_posterior(self, normal_stds=2):
+        """ Plot the posterior in the context of the prior """
+        fig, axes = plt.subplots(nrows=self.ndim, figsize=(8, 4*self.ndim))
+        N = 1000
+        from scipy.stats import gaussian_kde
+
+        for i, (ax, key) in enumerate(zip(axes, self.theta_keys)):
+            prior_dict = self.theta_prior[key]
+            prior_func = self.generic_lnprior(**prior_dict)
+            if prior_dict['type'] == 'unif':
+                x = np.linspace(prior_dict['lower'], prior_dict['upper'], N)
+                prior = prior_func(x)
+                prior[0] = 0
+                prior[-1] = 0
+            elif prior_dict['type'] == 'norm':
+                lower = prior_dict['loc'] - normal_stds * prior_dict['scale']
+                upper = prior_dict['loc'] + normal_stds * prior_dict['scale']
+                x = np.linspace(lower, upper, N)
+                prior = prior_func(x)
+            else:
+                raise ValueError('Not implemented for prior type {}'.format(
+                    prior_dict['type']))
+            priorln = ax.plot(x, prior, 'r', label='prior')
+            ax.set_xlabel(self.theta_symbols[i])
+
+            s = self.samples[:, i]
+            while len(s) > 10**4:
+                # random downsample to avoid slow calculation of kde
+                s = np.random.choice(s, size=int(len(s)/2.))
+            kde = gaussian_kde(s)
+            ax2 = ax.twinx()
+            postln = ax2.plot(x, kde.pdf(x), 'k', label='posterior')
+            ax2.set_yticklabels([])
+            ax.set_yticklabels([])
+
+        lns = priorln + postln
+        labs = [l.get_label() for l in lns]
+        axes[0].legend(lns, labs, loc=1, framealpha=0.8)
+
+        fig.savefig('{}/{}_prior_posterior.png'.format(
+            self.outdir, self.label))
+
     def generic_lnprior(self, **kwargs):
         """ Return a lambda function of the pdf