Skip to content
Snippets Groups Projects
Commit 241d98c7 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Two improvements to general running of the searches

1) Make tqdm import allow error and use it for the grid search as well
2) Allow functionality to change the flatten method in the 2D plot
parent a666ce23
No related branches found
No related tags found
No related merge requests found
...@@ -21,6 +21,12 @@ import dill as pickle ...@@ -21,6 +21,12 @@ import dill as pickle
import lal import lal
import lalpulsar import lalpulsar
try:
from tqdm import tqdm
except ImportError:
def tqdm(x):
return x
plt.rcParams['text.usetex'] = True plt.rcParams['text.usetex'] = True
plt.rcParams['axes.formatter.useoffset'] = False plt.rcParams['axes.formatter.useoffset'] = False
...@@ -670,12 +676,8 @@ class MCMCSearch(BaseSearchClass): ...@@ -670,12 +676,8 @@ class MCMCSearch(BaseSearchClass):
return p0 return p0
def run_sampler_with_progress_bar(self, sampler, ns, p0): def run_sampler_with_progress_bar(self, sampler, ns, p0):
try:
from tqdm import tqdm
for result in tqdm(sampler.sample(p0, iterations=ns), total=ns): for result in tqdm(sampler.sample(p0, iterations=ns), total=ns):
pass pass
except ImportError:
sampler.run_mcmc(p0, ns)
return sampler return sampler
def run(self, proposal_scale_factor=2): def run(self, proposal_scale_factor=2):
...@@ -1544,7 +1546,7 @@ class GridSearch(BaseSearchClass): ...@@ -1544,7 +1546,7 @@ class GridSearch(BaseSearchClass):
len(self.input_data))) len(self.input_data)))
data = [] data = []
for vals in self.input_data: for vals in tqdm(self.input_data):
FS = self.search.run_computefstatistic_single_point(*vals) FS = self.search.run_computefstatistic_single_point(*vals)
data.append(list(vals) + [FS]) data.append(list(vals) + [FS])
...@@ -1591,7 +1593,8 @@ class GridSearch(BaseSearchClass): ...@@ -1591,7 +1593,8 @@ class GridSearch(BaseSearchClass):
fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label)) fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label))
def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None, def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None,
add_mismatch=None, xN=None, yN=None, flat_keys=[]): add_mismatch=None, xN=None, yN=None, flat_keys=[],
flatten_method=np.max):
""" Plots a 2D grid of 2F values """ Plots a 2D grid of 2F values
Parameters Parameters
...@@ -1599,6 +1602,8 @@ class GridSearch(BaseSearchClass): ...@@ -1599,6 +1602,8 @@ class GridSearch(BaseSearchClass):
add_mismatch: tuple (xhat, yhat, Tseg) add_mismatch: tuple (xhat, yhat, Tseg)
If not None, add a secondary axis with the metric mismatch from the If not None, add a secondary axis with the metric mismatch from the
point xhat, yhat with duration Tseg point xhat, yhat with duration Tseg
flatten_method: np.max
Function to use in flattening the flat_keys
""" """
if ax is None: if ax is None:
fig, ax = plt.subplots() fig, ax = plt.subplots()
...@@ -1616,7 +1621,7 @@ class GridSearch(BaseSearchClass): ...@@ -1616,7 +1621,7 @@ class GridSearch(BaseSearchClass):
Z = z.reshape(shape) Z = z.reshape(shape)
while Z.ndim > 2: while Z.ndim > 2:
Z = np.mean(Z, axis=-1) Z = flatten_method(Z, axis=-1)
pax = ax.pcolormesh(X, Y, Z, cmap=plt.cm.viridis, vmin=vmin, vmax=vmax) pax = ax.pcolormesh(X, Y, Z, cmap=plt.cm.viridis, vmin=vmin, vmax=vmax)
plt.colorbar(pax, ax=ax) plt.colorbar(pax, ax=ax)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment