diff --git a/pyfstat.py b/pyfstat.py index a38b0a421330efb3824d4b7e976f515dc472c51a..3397b45c856281bcfca44b11936b721f8e66e667 100755 --- a/pyfstat.py +++ b/pyfstat.py @@ -21,6 +21,12 @@ import dill as pickle import lal import lalpulsar +try: + from tqdm import tqdm +except ImportError: + def tqdm(x): + return x + plt.rcParams['text.usetex'] = True plt.rcParams['axes.formatter.useoffset'] = False @@ -670,12 +676,8 @@ class MCMCSearch(BaseSearchClass): return p0 def run_sampler_with_progress_bar(self, sampler, ns, p0): - try: - from tqdm import tqdm - for result in tqdm(sampler.sample(p0, iterations=ns), total=ns): - pass - except ImportError: - sampler.run_mcmc(p0, ns) + for result in tqdm(sampler.sample(p0, iterations=ns), total=ns): + pass return sampler def run(self, proposal_scale_factor=2): @@ -1544,7 +1546,7 @@ class GridSearch(BaseSearchClass): len(self.input_data))) data = [] - for vals in self.input_data: + for vals in tqdm(self.input_data): FS = self.search.run_computefstatistic_single_point(*vals) data.append(list(vals) + [FS]) @@ -1591,7 +1593,8 @@ class GridSearch(BaseSearchClass): fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label)) def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None, - add_mismatch=None, xN=None, yN=None, flat_keys=[]): + add_mismatch=None, xN=None, yN=None, flat_keys=[], + flatten_method=np.max): """ Plots a 2D grid of 2F values Parameters @@ -1599,6 +1602,8 @@ class GridSearch(BaseSearchClass): add_mismatch: tuple (xhat, yhat, Tseg) If not None, add a secondary axis with the metric mismatch from the point xhat, yhat with duration Tseg + flatten_method: np.max + Function to use in flattening the flat_keys """ if ax is None: fig, ax = plt.subplots() @@ -1616,7 +1621,7 @@ class GridSearch(BaseSearchClass): Z = z.reshape(shape) while Z.ndim > 2: - Z = np.mean(Z, axis=-1) + Z = flatten_method(Z, axis=-1) pax = ax.pcolormesh(X, Y, Z, cmap=plt.cm.viridis, vmin=vmin, vmax=vmax) plt.colorbar(pax, ax=ax)