From cfc9d8fe477dc87f7a209a6ee036cd0ed8a0415d Mon Sep 17 00:00:00 2001
From: "gregory.ashton" <gregory.ashton@ligo.org>
Date: Fri, 12 May 2017 12:20:33 +0200
Subject: [PATCH] Improves add_prior functionality

- Uses the actual prior and not the log-prior as previous
- Adds option to plot the full extent (currently only for uniform
  priors)
- ensures the xlims and ylims are reset correctly
---
 pyfstat/mcmc_based_searches.py | 41 ++++++++++++++++++++++++----------
 1 file changed, 29 insertions(+), 12 deletions(-)

diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py
index 6000237..9bb1c22 100644
--- a/pyfstat/mcmc_based_searches.py
+++ b/pyfstat/mcmc_based_searches.py
@@ -524,8 +524,9 @@ class MCMCSearch(core.BaseSearchClass):
         ----------
         figsize: tuple (7, 7)
             Figure size in inches (passed to plt.subplots)
-        add_prior: bool
-            If true, plot the prior as a red line
+        add_prior: bool, str
+            If true, plot the prior as a red line. If 'full' then for uniform
+            priors plot the full extent of the prior.
         nstds: float
             The number of standard deviations to plot centered on the mean
         label_offset: float
@@ -546,7 +547,7 @@ class MCMCSearch(core.BaseSearchClass):
         save_fig: bool
             If true, save the figure, else return the fig, axes
 
-        Note: kwargs are passed on to corner.coner
+        Note: kwargs are passed on to corner.corner
 
         """
 
@@ -599,6 +600,10 @@ class MCMCSearch(core.BaseSearchClass):
             else:
                 _range = None
 
+            hist_kwargs = kwargs.pop('hist_kwargs', dict())
+            if 'normed' not in hist_kwargs:
+                hist_kwargs['normed'] = True
+
             fig_triangle = corner.corner(samples_plt,
                                          labels=labels,
                                          fig=fig,
@@ -610,6 +615,7 @@ class MCMCSearch(core.BaseSearchClass):
                                          data_kwargs={'alpha': 0.1,
                                                       'ms': 0.5},
                                          range=_range,
+                                         hist_kwargs=hist_kwargs,
                                          **kwargs)
 
             axes_list = fig_triangle.get_axes()
@@ -626,7 +632,7 @@ class MCMCSearch(core.BaseSearchClass):
             fig.subplots_adjust(hspace=0.05, wspace=0.05)
 
             if add_prior:
-                self._add_prior_to_corner(axes, self.samples)
+                self._add_prior_to_corner(axes, self.samples, add_prior)
 
             if save_fig:
                 fig_triangle.savefig('{}/{}_corner.png'.format(
@@ -634,19 +640,30 @@ class MCMCSearch(core.BaseSearchClass):
             else:
                 return fig, axes
 
-    def _add_prior_to_corner(self, axes, samples):
+    def _add_prior_to_corner(self, axes, samples, add_prior):
         for i, key in enumerate(self.theta_keys):
             ax = axes[i][i]
-            xlim = ax.get_xlim()
             s = samples[:, i]
-            prior = self._generic_lnprior(**self.theta_prior[key])
-            x = np.linspace(s.min(), s.max(), 100)
+            lnprior = self._generic_lnprior(**self.theta_prior[key])
+            if add_prior == 'full' and self.theta_prior[key]['type'] == 'unif':
+                lower = self.theta_prior[key]['lower']
+                upper = self.theta_prior[key]['upper']
+                r = upper-lower
+                xlim = [lower-0.05*r, upper+0.05*r]
+                x = np.linspace(xlim[0], xlim[1], 1000)
+            else:
+                xlim = ax.get_xlim()
+                x = np.linspace(s.min(), s.max(), 1000)
             multiplier = self._get_rescale_multiplier_for_key(key)
             subtractor = self._get_rescale_subtractor_for_key(key)
-            ax2 = ax.twinx()
-            ax2.get_yaxis().set_visible(False)
-            ax2.plot((x-subtractor)*multiplier, [prior(xi) for xi in x], '-r')
-            ax2.set_xlim(xlim)
+            ax.plot((x-subtractor)*multiplier,
+                    [np.exp(lnprior(xi)) for xi in x], '-C3',
+                    label='prior')
+
+            for j in range(i, self.ndim):
+                axes[j][i].set_xlim(xlim[0], xlim[1])
+            for k in range(0, i):
+                axes[i][k].set_ylim(xlim[0], xlim[1])
 
     def plot_prior_posterior(self, normal_stds=2):
         """ Plot the posterior in the context of the prior """
-- 
GitLab