diff --git a/pyfstat/grid_based_searches.py b/pyfstat/grid_based_searches.py
index d231961df1465f14e231c479f78b305695a1a50b..a7eebcc942a20e72259b7fea41e7b406ecc15089 100644
--- a/pyfstat/grid_based_searches.py
+++ b/pyfstat/grid_based_searches.py
@@ -20,12 +20,17 @@ import lal
 
 class GridSearch(BaseSearchClass):
     """ Gridded search using ComputeFstat """
+    tex_labels = {'F0': '$f$', 'F1': '$\dot{f}$', 'F2': '$\ddot{f}$',
+                  'Alpha': r'$\alpha$', 'Delta': r'$\delta$'}
+    tex_labels0 = {'F0': '$-f_0$', 'F1': '$-\dot{f}_0$', 'F2': '$-\ddot{f}_0$',
+                   'Alpha': r'$-\alpha_0$', 'Deltas': r'$-\delta_0$'}
+
     @helper_functions.initializer
-    def __init__(self, label, outdir, sftfilepattern, F0s=[0], F1s=[0], F2s=[0],
-                 Alphas=[0], Deltas=[0], tref=None, minStartTime=None,
-                 maxStartTime=None, nsegs=1, BSGL=False, minCoverFreq=None,
-                 maxCoverFreq=None, detectors=None, SSBprec=None,
-                 injectSources=None, input_arrays=False, assumeSqrtSX=None):
+    def __init__(self, label, outdir, sftfilepattern, F0s, F1s, F2s, Alphas,
+                 Deltas, tref=None, minStartTime=None, maxStartTime=None,
+                 nsegs=1, BSGL=False, minCoverFreq=None, maxCoverFreq=None,
+                 detectors=None, SSBprec=None, injectSources=None,
+                 input_arrays=False, assumeSqrtSX=None):
         """
         Parameters
         ----------
@@ -50,6 +55,9 @@ class GridSearch(BaseSearchClass):
             os.mkdir(outdir)
         self.set_out_file()
         self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta']
+        self.search_keys = [x+'s' for x in self.keys[2:]]
+        for k in self.search_keys:
+            setattr(self, k, np.atleast_1d(getattr(self, k)))
 
     def inititate_search_object(self):
         logging.info('Setting up search object')
@@ -88,8 +96,8 @@ class GridSearch(BaseSearchClass):
     def get_input_data_array(self):
         logging.info("Generating input data array")
         coord_arrays = []
-        for tup in ([self.minStartTime], [self.maxStartTime], self.F0s, self.F1s, self.F2s,
-                    self.Alphas, self.Deltas):
+        for tup in ([self.minStartTime], [self.maxStartTime], self.F0s,
+                    self.F1s, self.F2s, self.Alphas, self.Deltas):
             coord_arrays.append(self.get_array_from_tuple(tup))
 
         self.input_data_generator_len = np.prod([len(k) for k in coord_arrays])
@@ -112,15 +120,6 @@ class GridSearch(BaseSearchClass):
 
         logging.info('No data caching available')
         return False
-        #data = np.atleast_2d(np.genfromtxt(self.out_file, delimiter=' '))
-        #if np.all(data[:, 0:-1] == self.input_data):
-        #    logging.info(
-        #        'Old data found with matching input, no search performed')
-        #    return data
-        #else:
-        #    logging.info(
-        #        'Old data found, input differs, continuing with grid search')
-        #    return False
 
     def run(self, return_data=False):
         self.get_input_data_array()
@@ -129,10 +128,12 @@ class GridSearch(BaseSearchClass):
             self.data = old_data
             return
 
-        self.inititate_search_object()
+        if hasattr(self, 'search') is False:
+            self.inititate_search_object()
 
         data = []
-        for vals in tqdm(self.input_data_generator):
+        for vals in tqdm(self.input_data_generator,
+                         total=self.input_data_generator_len):
             FS = self.search.get_det_stat(*vals)
             data.append(list(vals) + [FS])
 
@@ -170,18 +171,27 @@ class GridSearch(BaseSearchClass):
             m = self.convert_F1_to_mismatch(y, yhat, Tseg)
             axY.set_ylim(m[0], m[-1])
 
-    def plot_1D(self, xkey):
-        fig, ax = plt.subplots()
+    def plot_1D(self, xkey, ax=None, x0=None, savefig=True):
+        if ax is None:
+            fig, ax = plt.subplots()
         xidx = self.keys.index(xkey)
         x = np.unique(self.data[:, xidx])
         z = self.data[:, -1]
-        plt.plot(x, z)
-        fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label))
+        ax.plot(x, z)
+        if x0:
+            ax.set_xlabel(self.tex_labels[xkey]+self.tex_labels0[xkey])
+        else:
+            ax.set_xlabel(self.tex_labels[xkey])
+        if savefig:
+            fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label))
+        else:
+            return ax
 
     def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None,
                 add_mismatch=None, xN=None, yN=None, flat_keys=[],
                 rel_flat_idxs=[], flatten_method=np.max, title=None,
-                predicted_twoF=None, cm=None, cbarkwargs={}, x0=None, y0=None):
+                predicted_twoF=None, cm=None, cbarkwargs={}, x0=None, y0=None,
+                colorbar=False):
         """ Plots a 2D grid of 2F values
 
         Parameters
@@ -223,24 +233,23 @@ class GridSearch(BaseSearchClass):
                 cm = plt.cm.viridis
 
         pax = ax.pcolormesh(X, Y, Z, cmap=cm, vmin=vmin, vmax=vmax)
-        cb = plt.colorbar(pax, ax=ax, **cbarkwargs)
-        cb.set_label('$2\mathcal{F}$')
+        if colorbar:
+            cb = plt.colorbar(pax, ax=ax, **cbarkwargs)
+            cb.set_label('$2\mathcal{F}$')
 
         if add_mismatch:
             self.add_mismatch_to_ax(ax, x, y, xkey, ykey, *add_mismatch)
 
         ax.set_xlim(x[0], x[-1])
         ax.set_ylim(y[0], y[-1])
-        labels = {'F0': '$f$', 'F1': '$\dot{f}$'}
-        labels0 = {'F0': '$-f_0$', 'F1': '$-\dot{f}_0$'}
         if x0:
-            ax.set_xlabel(labels[xkey]+labels0[xkey])
+            ax.set_xlabel(self.tex_labels[xkey]+self.tex_labels0[xkey])
         else:
-            ax.set_xlabel(labels[xkey])
+            ax.set_xlabel(self.tex_labels[xkey])
         if y0:
-            ax.set_ylabel(labels[ykey]+labels0[ykey])
+            ax.set_ylabel(self.tex_labels[ykey]+self.tex_labels0[ykey])
         else:
-            ax.set_ylabel(labels[ykey])
+            ax.set_ylabel(self.tex_labels[ykey])
 
         if title:
             ax.set_title(title)
@@ -287,12 +296,11 @@ class GridSearch(BaseSearchClass):
 class SliceGridSearch(GridSearch):
     """ Slice gridded search using ComputeFstat """
     @helper_functions.initializer
-    def __init__(self, label, outdir, sftfilepattern, F0s=[0], F1s=[0], F2s=[0],
-                 Alphas=[0], Deltas=[0], tref=None, minStartTime=None,
-                 maxStartTime=None, nsegs=1, BSGL=False, minCoverFreq=None,
-                 maxCoverFreq=None, detectors=None, SSBprec=None,
-                 injectSources=None, input_arrays=False, assumeSqrtSX=None,
-                 Lambda0=None):
+    def __init__(self, label, outdir, sftfilepattern, F0s, F1s, F2s, Alphas,
+                 Deltas, tref=None, minStartTime=None, maxStartTime=None,
+                 nsegs=1, BSGL=False, minCoverFreq=None, maxCoverFreq=None,
+                 detectors=None, SSBprec=None, injectSources=None,
+                 input_arrays=False, assumeSqrtSX=None, Lambda0=None):
         """
         Parameters
         ----------
@@ -317,46 +325,76 @@ class SliceGridSearch(GridSearch):
             os.mkdir(outdir)
         self.set_out_file()
         self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta']
+        self.ndim = 0
+        self.thetas = [F0s, F1s, Alphas, Deltas]
+        self.ndim = 4
 
+        self.search_keys = ['F0', 'F1', 'Alpha', 'Delta']
         self.Lambda0 = np.array(Lambda0)
-        if len(self.Lambda0) != len(self.keys):
+        if len(self.Lambda0) != len(self.search_keys):
             raise ValueError(
-                'Lambda0 must be of length {}'.format(len(self.keys)))
-
-    def run(self, return_data=False):
-        self.get_input_data_array()
-
-        self.Lambda0s_grid = []
-        for j, arr in enumerate(self.coord_arrays):
-            i = np.argmin(np.abs(self.Lambda0[j]-arr))
-            self.Lambda0s_grid.append(arr[i])
-
-        old_data = self.check_old_data_is_okay_to_use()
-        if old_data is not False:
-            self.data = old_data
-            return
-
-        self.inititate_search_object()
-
-        logging.info('Total number of grid points is {}'.format(
-            self.input_data_generator_len))
-
-        data = []
-        for vals in tqdm(self.input_data_generator,
-                         total=self.input_data_generator_len):
-            if np.sum(np.array(vals) != np.array(self.Lambda0s_grid)) < 3:
-                FS = self.search.get_det_stat(*vals)
-                data.append(list(vals) + [FS])
-            else:
-                data.append(list(vals) + [0])
-
-        data = np.array(data, dtype=np.float)
-        if return_data:
-            return data
-        else:
-            logging.info('Saving data to {}'.format(self.out_file))
-            np.savetxt(self.out_file, data, delimiter=' ')
-            self.data = data
+                'Lambda0 must be of length {}'.format(len(self.search_keys)))
+
+    def run(self, factor=2):
+        lbdim = 0.5 * factor   # size of left/bottom margin
+        trdim = 0.2 * factor   # size of top/right margin
+        whspace = 0.05         # w/hspace size
+        plotdim = factor * self.ndim + factor * (self.ndim - 1.) * whspace
+        dim = lbdim + plotdim + trdim
+
+        fig, axes = plt.subplots(self.ndim, self.ndim, figsize=(dim, dim))
+
+        # Format the figure.
+        lb = lbdim / dim
+        tr = (lbdim + plotdim) / dim
+        fig.subplots_adjust(left=lb, bottom=lb, right=tr, top=tr,
+                            wspace=whspace, hspace=whspace)
+
+        search = GridSearch(
+            self.label, self.outdir, self.sftfilepattern,
+            F0s=self.Lambda0[0], F1s=self.Lambda0[1], F2s=self.F2s[0],
+            Alphas=self.Lambda0[2], Deltas=self.Lambda0[3], tref=self.tref)
+
+        for i, ikey in enumerate(self.search_keys):
+            setattr(search, ikey+'s', self.thetas[i])
+            search.run()
+            axes[i, i] = search.plot_1D(ikey, ax=axes[i, i], savefig=False)
+            setattr(search, ikey+'s', [self.Lambda0[i]])
+
+            for j, jkey in enumerate(self.search_keys):
+                ax = axes[i, j]
+
+                if j > i:
+                    ax.set_frame_on(False)
+                    ax.set_xticks([])
+                    ax.set_yticks([])
+                    continue
+
+                ax.get_shared_x_axes().join(axes[self.ndim-1, j], ax)
+                if i < self.ndim - 1:
+                    ax.set_xticklabels([])
+                if j < i:
+                    ax.get_shared_y_axes().join(axes[i, i-1], ax)
+                    if j > 0:
+                        ax.set_yticklabels([])
+                if j == i:
+                    continue
+
+                max_n_ticks = 3
+                ax.xaxis.set_major_locator(
+                    matplotlib.ticker.MaxNLocator(max_n_ticks, prune="upper"))
+                ax.yaxis.set_major_locator(
+                    matplotlib.ticker.MaxNLocator(max_n_ticks, prune="upper"))
+
+                setattr(search, ikey+'s', self.thetas[i])
+                setattr(search, jkey+'s', self.thetas[j])
+                search.run()
+                ax = search.plot_2D(jkey, ikey, ax=ax, save=False)
+                setattr(search, ikey+'s', [self.Lambda0[i]])
+                setattr(search, jkey+'s', [self.Lambda0[j]])
+
+        fig.savefig(
+            '{}/{}_slice_projection.png'.format(self.outdir, self.label))
 
 
 class GridUniformPriorSearch():