From 241d98c7e1a07b46dd61103b3ca2cadec07c2ca3 Mon Sep 17 00:00:00 2001
From: "gregory.ashton" <gregory.ashton@ligo.org>
Date: Thu, 13 Oct 2016 16:23:59 +0200
Subject: [PATCH] Two improvements to general running of the searches

1) Make tqdm import allow error and use it for the grid search as well
2) Allow functionality to change the flatten method in the 2D plot
---
 pyfstat.py | 23 ++++++++++++++---------
 1 file changed, 14 insertions(+), 9 deletions(-)

diff --git a/pyfstat.py b/pyfstat.py
index a38b0a4..3397b45 100755
--- a/pyfstat.py
+++ b/pyfstat.py
@@ -21,6 +21,12 @@ import dill as pickle
 import lal
 import lalpulsar
 
+try:
+    from tqdm import tqdm
+except ImportError:
+    def tqdm(x):
+        return x
+
 plt.rcParams['text.usetex'] = True
 plt.rcParams['axes.formatter.useoffset'] = False
 
@@ -670,12 +676,8 @@ class MCMCSearch(BaseSearchClass):
         return p0
 
     def run_sampler_with_progress_bar(self, sampler, ns, p0):
-        try:
-            from tqdm import tqdm
-            for result in tqdm(sampler.sample(p0, iterations=ns), total=ns):
-                pass
-        except ImportError:
-            sampler.run_mcmc(p0, ns)
+        for result in tqdm(sampler.sample(p0, iterations=ns), total=ns):
+            pass
         return sampler
 
     def run(self, proposal_scale_factor=2):
@@ -1544,7 +1546,7 @@ class GridSearch(BaseSearchClass):
             len(self.input_data)))
 
         data = []
-        for vals in self.input_data:
+        for vals in tqdm(self.input_data):
             FS = self.search.run_computefstatistic_single_point(*vals)
             data.append(list(vals) + [FS])
 
@@ -1591,7 +1593,8 @@ class GridSearch(BaseSearchClass):
         fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label))
 
     def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None,
-                add_mismatch=None, xN=None, yN=None, flat_keys=[]):
+                add_mismatch=None, xN=None, yN=None, flat_keys=[],
+                flatten_method=np.max):
         """ Plots a 2D grid of 2F values
 
         Parameters
@@ -1599,6 +1602,8 @@ class GridSearch(BaseSearchClass):
         add_mismatch: tuple (xhat, yhat, Tseg)
             If not None, add a secondary axis with the metric mismatch from the
             point xhat, yhat with duration Tseg
+        flatten_method: np.max
+            Function to use in flattening the flat_keys
         """
         if ax is None:
             fig, ax = plt.subplots()
@@ -1616,7 +1621,7 @@ class GridSearch(BaseSearchClass):
         Z = z.reshape(shape)
 
         while Z.ndim > 2:
-            Z = np.mean(Z, axis=-1)
+            Z = flatten_method(Z, axis=-1)
 
         pax = ax.pcolormesh(X, Y, Z, cmap=plt.cm.viridis, vmin=vmin, vmax=vmax)
         plt.colorbar(pax, ax=ax)
-- 
GitLab