Commit 241d98c7 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Two improvements to general running of the searches

1) Make tqdm import allow error and use it for the grid search as well
2) Allow functionality to change the flatten method in the 2D plot
parent a666ce23
......@@ -21,6 +21,12 @@ import dill as pickle
import lal
import lalpulsar
from tqdm import tqdm
except ImportError:
def tqdm(x):
return x
plt.rcParams['text.usetex'] = True
plt.rcParams['axes.formatter.useoffset'] = False
......@@ -670,12 +676,8 @@ class MCMCSearch(BaseSearchClass):
return p0
def run_sampler_with_progress_bar(self, sampler, ns, p0):
from tqdm import tqdm
for result in tqdm(sampler.sample(p0, iterations=ns), total=ns):
except ImportError:
sampler.run_mcmc(p0, ns)
return sampler
def run(self, proposal_scale_factor=2):
......@@ -1544,7 +1546,7 @@ class GridSearch(BaseSearchClass):
data = []
for vals in self.input_data:
for vals in tqdm(self.input_data):
FS =*vals)
data.append(list(vals) + [FS])
......@@ -1591,7 +1593,8 @@ class GridSearch(BaseSearchClass):
fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label))
def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None,
add_mismatch=None, xN=None, yN=None, flat_keys=[]):
add_mismatch=None, xN=None, yN=None, flat_keys=[],
""" Plots a 2D grid of 2F values
......@@ -1599,6 +1602,8 @@ class GridSearch(BaseSearchClass):
add_mismatch: tuple (xhat, yhat, Tseg)
If not None, add a secondary axis with the metric mismatch from the
point xhat, yhat with duration Tseg
flatten_method: np.max
Function to use in flattening the flat_keys
if ax is None:
fig, ax = plt.subplots()
......@@ -1616,7 +1621,7 @@ class GridSearch(BaseSearchClass):
Z = z.reshape(shape)
while Z.ndim > 2:
Z = np.mean(Z, axis=-1)
Z = flatten_method(Z, axis=-1)
pax = ax.pcolormesh(X, Y, Z,, vmin=vmin, vmax=vmax)
plt.colorbar(pax, ax=ax)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment