From 241d98c7e1a07b46dd61103b3ca2cadec07c2ca3 Mon Sep 17 00:00:00 2001 From: "gregory.ashton" <gregory.ashton@ligo.org> Date: Thu, 13 Oct 2016 16:23:59 +0200 Subject: [PATCH] Two improvements to general running of the searches 1) Make tqdm import allow error and use it for the grid search as well 2) Allow functionality to change the flatten method in the 2D plot --- pyfstat.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/pyfstat.py b/pyfstat.py index a38b0a4..3397b45 100755 --- a/pyfstat.py +++ b/pyfstat.py @@ -21,6 +21,12 @@ import dill as pickle import lal import lalpulsar +try: + from tqdm import tqdm +except ImportError: + def tqdm(x): + return x + plt.rcParams['text.usetex'] = True plt.rcParams['axes.formatter.useoffset'] = False @@ -670,12 +676,8 @@ class MCMCSearch(BaseSearchClass): return p0 def run_sampler_with_progress_bar(self, sampler, ns, p0): - try: - from tqdm import tqdm - for result in tqdm(sampler.sample(p0, iterations=ns), total=ns): - pass - except ImportError: - sampler.run_mcmc(p0, ns) + for result in tqdm(sampler.sample(p0, iterations=ns), total=ns): + pass return sampler def run(self, proposal_scale_factor=2): @@ -1544,7 +1546,7 @@ class GridSearch(BaseSearchClass): len(self.input_data))) data = [] - for vals in self.input_data: + for vals in tqdm(self.input_data): FS = self.search.run_computefstatistic_single_point(*vals) data.append(list(vals) + [FS]) @@ -1591,7 +1593,8 @@ class GridSearch(BaseSearchClass): fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label)) def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None, - add_mismatch=None, xN=None, yN=None, flat_keys=[]): + add_mismatch=None, xN=None, yN=None, flat_keys=[], + flatten_method=np.max): """ Plots a 2D grid of 2F values Parameters @@ -1599,6 +1602,8 @@ class GridSearch(BaseSearchClass): add_mismatch: tuple (xhat, yhat, Tseg) If not None, add a secondary axis with the metric mismatch from the point xhat, yhat with duration Tseg + flatten_method: np.max + Function to use in flattening the flat_keys """ if ax is None: fig, ax = plt.subplots() @@ -1616,7 +1621,7 @@ class GridSearch(BaseSearchClass): Z = z.reshape(shape) while Z.ndim > 2: - Z = np.mean(Z, axis=-1) + Z = flatten_method(Z, axis=-1) pax = ax.pcolormesh(X, Y, Z, cmap=plt.cm.viridis, vmin=vmin, vmax=vmax) plt.colorbar(pax, ax=ax) -- GitLab