diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py index e499d5caaa71e65d4e9296d162d24f2ee3143b6f..7d8edc170b18138297c2d97d6d6809e17d3d5a25 100644 --- a/pyfstat/mcmc_based_searches.py +++ b/pyfstat/mcmc_based_searches.py @@ -507,7 +507,8 @@ class MCMCSearch(core.BaseSearchClass): def plot_corner(self, figsize=(7, 7), add_prior=False, nstds=None, label_offset=0.4, dpi=300, rc_context={}, - tglitch_ratio=False, **kwargs): + tglitch_ratio=False, fig_and_axes=None, save_fig=False, + **kwargs): """ Generate a corner plot of the posterior Using the `corner` package (https://pypi.python.org/pypi/corner/), @@ -533,6 +534,11 @@ class MCMCSearch(core.BaseSearchClass): If true, and tglitch is a parameter, plot posteriors as the fractional time at which the glitch occurs instead of the actual time + fig_and_axes: tuple + fig and axes to plot on, the axes must be of the right shape, + namely (ndim, ndim) + save_fig: bool + If true, save the figure, else return the fig, axes Note: kwargs are passed on to corner.coner @@ -540,7 +546,10 @@ class MCMCSearch(core.BaseSearchClass): if self.ndim < 2: with plt.rc_context(rc_context): - fig, ax = plt.subplots(figsize=figsize) + if fig_and_axes is None: + fig, ax = plt.subplots(figsize=figsize) + else: + fig, ax = fig_and_axes ax.hist(self.samples, bins=50, histtype='stepfilled') ax.set_xlabel(self.theta_symbols[0]) @@ -549,8 +558,11 @@ class MCMCSearch(core.BaseSearchClass): return with plt.rc_context(rc_context): - fig, axes = plt.subplots(self.ndim, self.ndim, - figsize=figsize) + if fig_and_axes is None: + fig, axes = plt.subplots(self.ndim, self.ndim, + figsize=figsize) + else: + fig, axes = fig_and_axes samples_plt = copy.copy(self.samples) labels = self._get_labels() @@ -572,6 +584,8 @@ class MCMCSearch(core.BaseSearchClass): median = np.median(s) std = np.std(s) _range.append((median - nstds*std, median + nstds*std)) + elif 'range' in kwargs: + _range = kwargs.pop('range') else: _range = None @@ -604,8 +618,11 @@ class MCMCSearch(core.BaseSearchClass): if add_prior: self._add_prior_to_corner(axes, self.samples) - fig_triangle.savefig('{}/{}_corner.png'.format( - self.outdir, self.label), dpi=dpi) + if save_fig: + fig_triangle.savefig('{}/{}_corner.png'.format( + self.outdir, self.label), dpi=dpi) + else: + return fig, axes def _add_prior_to_corner(self, axes, samples): for i, key in enumerate(self.theta_keys):