diff --git a/pyfstat.py b/pyfstat.py
index a38b0a421330efb3824d4b7e976f515dc472c51a..3397b45c856281bcfca44b11936b721f8e66e667 100755
--- a/pyfstat.py
+++ b/pyfstat.py
@@ -21,6 +21,12 @@ import dill as pickle
 import lal
 import lalpulsar
 
+try:
+    from tqdm import tqdm
+except ImportError:
+    def tqdm(x):
+        return x
+
 plt.rcParams['text.usetex'] = True
 plt.rcParams['axes.formatter.useoffset'] = False
 
@@ -670,12 +676,8 @@ class MCMCSearch(BaseSearchClass):
         return p0
 
     def run_sampler_with_progress_bar(self, sampler, ns, p0):
-        try:
-            from tqdm import tqdm
-            for result in tqdm(sampler.sample(p0, iterations=ns), total=ns):
-                pass
-        except ImportError:
-            sampler.run_mcmc(p0, ns)
+        for result in tqdm(sampler.sample(p0, iterations=ns), total=ns):
+            pass
         return sampler
 
     def run(self, proposal_scale_factor=2):
@@ -1544,7 +1546,7 @@ class GridSearch(BaseSearchClass):
             len(self.input_data)))
 
         data = []
-        for vals in self.input_data:
+        for vals in tqdm(self.input_data):
             FS = self.search.run_computefstatistic_single_point(*vals)
             data.append(list(vals) + [FS])
 
@@ -1591,7 +1593,8 @@ class GridSearch(BaseSearchClass):
         fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label))
 
     def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None,
-                add_mismatch=None, xN=None, yN=None, flat_keys=[]):
+                add_mismatch=None, xN=None, yN=None, flat_keys=[],
+                flatten_method=np.max):
         """ Plots a 2D grid of 2F values
 
         Parameters
@@ -1599,6 +1602,8 @@ class GridSearch(BaseSearchClass):
         add_mismatch: tuple (xhat, yhat, Tseg)
             If not None, add a secondary axis with the metric mismatch from the
             point xhat, yhat with duration Tseg
+        flatten_method: np.max
+            Function to use in flattening the flat_keys
         """
         if ax is None:
             fig, ax = plt.subplots()
@@ -1616,7 +1621,7 @@ class GridSearch(BaseSearchClass):
         Z = z.reshape(shape)
 
         while Z.ndim > 2:
-            Z = np.mean(Z, axis=-1)
+            Z = flatten_method(Z, axis=-1)
 
         pax = ax.pcolormesh(X, Y, Z, cmap=plt.cm.viridis, vmin=vmin, vmax=vmax)
         plt.colorbar(pax, ax=ax)