diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py
index e499d5caaa71e65d4e9296d162d24f2ee3143b6f..7d8edc170b18138297c2d97d6d6809e17d3d5a25 100644
--- a/pyfstat/mcmc_based_searches.py
+++ b/pyfstat/mcmc_based_searches.py
@@ -507,7 +507,8 @@ class MCMCSearch(core.BaseSearchClass):
 
     def plot_corner(self, figsize=(7, 7), add_prior=False, nstds=None,
                     label_offset=0.4, dpi=300, rc_context={},
-                    tglitch_ratio=False, **kwargs):
+                    tglitch_ratio=False, fig_and_axes=None, save_fig=False,
+                    **kwargs):
         """ Generate a corner plot of the posterior
 
         Using the `corner` package (https://pypi.python.org/pypi/corner/),
@@ -533,6 +534,11 @@ class MCMCSearch(core.BaseSearchClass):
             If true, and tglitch is a parameter, plot posteriors as the
             fractional time at which the glitch occurs instead of the actual
             time
+        fig_and_axes: tuple
+            fig and axes to plot on, the axes must be of the right shape,
+            namely (ndim, ndim)
+        save_fig: bool
+            If true, save the figure, else return the fig, axes
 
         Note: kwargs are passed on to corner.coner
 
@@ -540,7 +546,10 @@ class MCMCSearch(core.BaseSearchClass):
 
         if self.ndim < 2:
             with plt.rc_context(rc_context):
-                fig, ax = plt.subplots(figsize=figsize)
+                if fig_and_axes is None:
+                    fig, ax = plt.subplots(figsize=figsize)
+                else:
+                    fig, ax = fig_and_axes
                 ax.hist(self.samples, bins=50, histtype='stepfilled')
                 ax.set_xlabel(self.theta_symbols[0])
 
@@ -549,8 +558,11 @@ class MCMCSearch(core.BaseSearchClass):
             return
 
         with plt.rc_context(rc_context):
-            fig, axes = plt.subplots(self.ndim, self.ndim,
-                                     figsize=figsize)
+            if fig_and_axes is None:
+                fig, axes = plt.subplots(self.ndim, self.ndim,
+                                         figsize=figsize)
+            else:
+                fig, axes = fig_and_axes
 
             samples_plt = copy.copy(self.samples)
             labels = self._get_labels()
@@ -572,6 +584,8 @@ class MCMCSearch(core.BaseSearchClass):
                     median = np.median(s)
                     std = np.std(s)
                     _range.append((median - nstds*std, median + nstds*std))
+            elif 'range' in kwargs:
+                _range = kwargs.pop('range')
             else:
                 _range = None
 
@@ -604,8 +618,11 @@ class MCMCSearch(core.BaseSearchClass):
             if add_prior:
                 self._add_prior_to_corner(axes, self.samples)
 
-            fig_triangle.savefig('{}/{}_corner.png'.format(
-                self.outdir, self.label), dpi=dpi)
+            if save_fig:
+                fig_triangle.savefig('{}/{}_corner.png'.format(
+                    self.outdir, self.label), dpi=dpi)
+            else:
+                return fig, axes
 
     def _add_prior_to_corner(self, axes, samples):
         for i, key in enumerate(self.theta_keys):