Commit 17486105 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Clean up of SliceGridSearch

- Just run separate 2D and 1D searches for each panel itself
- Some additional clean up of the basic GridSearch method
parent 6580d08a
...@@ -20,12 +20,17 @@ import lal ...@@ -20,12 +20,17 @@ import lal
class GridSearch(BaseSearchClass): class GridSearch(BaseSearchClass):
""" Gridded search using ComputeFstat """ """ Gridded search using ComputeFstat """
tex_labels = {'F0': '$f$', 'F1': '$\dot{f}$', 'F2': '$\ddot{f}$',
'Alpha': r'$\alpha$', 'Delta': r'$\delta$'}
tex_labels0 = {'F0': '$-f_0$', 'F1': '$-\dot{f}_0$', 'F2': '$-\ddot{f}_0$',
'Alpha': r'$-\alpha_0$', 'Deltas': r'$-\delta_0$'}
@helper_functions.initializer @helper_functions.initializer
def __init__(self, label, outdir, sftfilepattern, F0s=[0], F1s=[0], F2s=[0], def __init__(self, label, outdir, sftfilepattern, F0s, F1s, F2s, Alphas,
Alphas=[0], Deltas=[0], tref=None, minStartTime=None, Deltas, tref=None, minStartTime=None, maxStartTime=None,
maxStartTime=None, nsegs=1, BSGL=False, minCoverFreq=None, nsegs=1, BSGL=False, minCoverFreq=None, maxCoverFreq=None,
maxCoverFreq=None, detectors=None, SSBprec=None, detectors=None, SSBprec=None, injectSources=None,
injectSources=None, input_arrays=False, assumeSqrtSX=None): input_arrays=False, assumeSqrtSX=None):
""" """
Parameters Parameters
---------- ----------
...@@ -50,6 +55,9 @@ class GridSearch(BaseSearchClass): ...@@ -50,6 +55,9 @@ class GridSearch(BaseSearchClass):
os.mkdir(outdir) os.mkdir(outdir)
self.set_out_file() self.set_out_file()
self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta'] self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta']
self.search_keys = [x+'s' for x in self.keys[2:]]
for k in self.search_keys:
setattr(self, k, np.atleast_1d(getattr(self, k)))
def inititate_search_object(self): def inititate_search_object(self):
logging.info('Setting up search object') logging.info('Setting up search object')
...@@ -88,8 +96,8 @@ class GridSearch(BaseSearchClass): ...@@ -88,8 +96,8 @@ class GridSearch(BaseSearchClass):
def get_input_data_array(self): def get_input_data_array(self):
logging.info("Generating input data array") logging.info("Generating input data array")
coord_arrays = [] coord_arrays = []
for tup in ([self.minStartTime], [self.maxStartTime], self.F0s, self.F1s, self.F2s, for tup in ([self.minStartTime], [self.maxStartTime], self.F0s,
self.Alphas, self.Deltas): self.F1s, self.F2s, self.Alphas, self.Deltas):
coord_arrays.append(self.get_array_from_tuple(tup)) coord_arrays.append(self.get_array_from_tuple(tup))
self.input_data_generator_len = np.prod([len(k) for k in coord_arrays]) self.input_data_generator_len = np.prod([len(k) for k in coord_arrays])
...@@ -112,15 +120,6 @@ class GridSearch(BaseSearchClass): ...@@ -112,15 +120,6 @@ class GridSearch(BaseSearchClass):
logging.info('No data caching available') logging.info('No data caching available')
return False return False
#data = np.atleast_2d(np.genfromtxt(self.out_file, delimiter=' '))
#if np.all(data[:, 0:-1] == self.input_data):
# logging.info(
# 'Old data found with matching input, no search performed')
# return data
#else:
# logging.info(
# 'Old data found, input differs, continuing with grid search')
# return False
def run(self, return_data=False): def run(self, return_data=False):
self.get_input_data_array() self.get_input_data_array()
...@@ -129,10 +128,12 @@ class GridSearch(BaseSearchClass): ...@@ -129,10 +128,12 @@ class GridSearch(BaseSearchClass):
self.data = old_data self.data = old_data
return return
self.inititate_search_object() if hasattr(self, 'search') is False:
self.inititate_search_object()
data = [] data = []
for vals in tqdm(self.input_data_generator): for vals in tqdm(self.input_data_generator,
total=self.input_data_generator_len):
FS = self.search.get_det_stat(*vals) FS = self.search.get_det_stat(*vals)
data.append(list(vals) + [FS]) data.append(list(vals) + [FS])
...@@ -170,18 +171,27 @@ class GridSearch(BaseSearchClass): ...@@ -170,18 +171,27 @@ class GridSearch(BaseSearchClass):
m = self.convert_F1_to_mismatch(y, yhat, Tseg) m = self.convert_F1_to_mismatch(y, yhat, Tseg)
axY.set_ylim(m[0], m[-1]) axY.set_ylim(m[0], m[-1])
def plot_1D(self, xkey): def plot_1D(self, xkey, ax=None, x0=None, savefig=True):
fig, ax = plt.subplots() if ax is None:
fig, ax = plt.subplots()
xidx = self.keys.index(xkey) xidx = self.keys.index(xkey)
x = np.unique(self.data[:, xidx]) x = np.unique(self.data[:, xidx])
z = self.data[:, -1] z = self.data[:, -1]
plt.plot(x, z) ax.plot(x, z)
fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label)) if x0:
ax.set_xlabel(self.tex_labels[xkey]+self.tex_labels0[xkey])
else:
ax.set_xlabel(self.tex_labels[xkey])
if savefig:
fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label))
else:
return ax
def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None, def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None,
add_mismatch=None, xN=None, yN=None, flat_keys=[], add_mismatch=None, xN=None, yN=None, flat_keys=[],
rel_flat_idxs=[], flatten_method=np.max, title=None, rel_flat_idxs=[], flatten_method=np.max, title=None,
predicted_twoF=None, cm=None, cbarkwargs={}, x0=None, y0=None): predicted_twoF=None, cm=None, cbarkwargs={}, x0=None, y0=None,
colorbar=False):
""" Plots a 2D grid of 2F values """ Plots a 2D grid of 2F values
Parameters Parameters
...@@ -223,24 +233,23 @@ class GridSearch(BaseSearchClass): ...@@ -223,24 +233,23 @@ class GridSearch(BaseSearchClass):
cm = plt.cm.viridis cm = plt.cm.viridis
pax = ax.pcolormesh(X, Y, Z, cmap=cm, vmin=vmin, vmax=vmax) pax = ax.pcolormesh(X, Y, Z, cmap=cm, vmin=vmin, vmax=vmax)
cb = plt.colorbar(pax, ax=ax, **cbarkwargs) if colorbar:
cb.set_label('$2\mathcal{F}$') cb = plt.colorbar(pax, ax=ax, **cbarkwargs)
cb.set_label('$2\mathcal{F}$')
if add_mismatch: if add_mismatch:
self.add_mismatch_to_ax(ax, x, y, xkey, ykey, *add_mismatch) self.add_mismatch_to_ax(ax, x, y, xkey, ykey, *add_mismatch)
ax.set_xlim(x[0], x[-1]) ax.set_xlim(x[0], x[-1])
ax.set_ylim(y[0], y[-1]) ax.set_ylim(y[0], y[-1])
labels = {'F0': '$f$', 'F1': '$\dot{f}$'}
labels0 = {'F0': '$-f_0$', 'F1': '$-\dot{f}_0$'}
if x0: if x0:
ax.set_xlabel(labels[xkey]+labels0[xkey]) ax.set_xlabel(self.tex_labels[xkey]+self.tex_labels0[xkey])
else: else:
ax.set_xlabel(labels[xkey]) ax.set_xlabel(self.tex_labels[xkey])
if y0: if y0:
ax.set_ylabel(labels[ykey]+labels0[ykey]) ax.set_ylabel(self.tex_labels[ykey]+self.tex_labels0[ykey])
else: else:
ax.set_ylabel(labels[ykey]) ax.set_ylabel(self.tex_labels[ykey])
if title: if title:
ax.set_title(title) ax.set_title(title)
...@@ -287,12 +296,11 @@ class GridSearch(BaseSearchClass): ...@@ -287,12 +296,11 @@ class GridSearch(BaseSearchClass):
class SliceGridSearch(GridSearch): class SliceGridSearch(GridSearch):
""" Slice gridded search using ComputeFstat """ """ Slice gridded search using ComputeFstat """
@helper_functions.initializer @helper_functions.initializer
def __init__(self, label, outdir, sftfilepattern, F0s=[0], F1s=[0], F2s=[0], def __init__(self, label, outdir, sftfilepattern, F0s, F1s, F2s, Alphas,
Alphas=[0], Deltas=[0], tref=None, minStartTime=None, Deltas, tref=None, minStartTime=None, maxStartTime=None,
maxStartTime=None, nsegs=1, BSGL=False, minCoverFreq=None, nsegs=1, BSGL=False, minCoverFreq=None, maxCoverFreq=None,
maxCoverFreq=None, detectors=None, SSBprec=None, detectors=None, SSBprec=None, injectSources=None,
injectSources=None, input_arrays=False, assumeSqrtSX=None, input_arrays=False, assumeSqrtSX=None, Lambda0=None):
Lambda0=None):
""" """
Parameters Parameters
---------- ----------
...@@ -317,46 +325,76 @@ class SliceGridSearch(GridSearch): ...@@ -317,46 +325,76 @@ class SliceGridSearch(GridSearch):
os.mkdir(outdir) os.mkdir(outdir)
self.set_out_file() self.set_out_file()
self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta'] self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta']
self.ndim = 0
self.thetas = [F0s, F1s, Alphas, Deltas]
self.ndim = 4
self.search_keys = ['F0', 'F1', 'Alpha', 'Delta']
self.Lambda0 = np.array(Lambda0) self.Lambda0 = np.array(Lambda0)
if len(self.Lambda0) != len(self.keys): if len(self.Lambda0) != len(self.search_keys):
raise ValueError( raise ValueError(
'Lambda0 must be of length {}'.format(len(self.keys))) 'Lambda0 must be of length {}'.format(len(self.search_keys)))
def run(self, return_data=False): def run(self, factor=2):
self.get_input_data_array() lbdim = 0.5 * factor # size of left/bottom margin
trdim = 0.2 * factor # size of top/right margin
self.Lambda0s_grid = [] whspace = 0.05 # w/hspace size
for j, arr in enumerate(self.coord_arrays): plotdim = factor * self.ndim + factor * (self.ndim - 1.) * whspace
i = np.argmin(np.abs(self.Lambda0[j]-arr)) dim = lbdim + plotdim + trdim
self.Lambda0s_grid.append(arr[i])
fig, axes = plt.subplots(self.ndim, self.ndim, figsize=(dim, dim))
old_data = self.check_old_data_is_okay_to_use()
if old_data is not False: # Format the figure.
self.data = old_data lb = lbdim / dim
return tr = (lbdim + plotdim) / dim
fig.subplots_adjust(left=lb, bottom=lb, right=tr, top=tr,
self.inititate_search_object() wspace=whspace, hspace=whspace)
logging.info('Total number of grid points is {}'.format( search = GridSearch(
self.input_data_generator_len)) self.label, self.outdir, self.sftfilepattern,
F0s=self.Lambda0[0], F1s=self.Lambda0[1], F2s=self.F2s[0],
data = [] Alphas=self.Lambda0[2], Deltas=self.Lambda0[3], tref=self.tref)
for vals in tqdm(self.input_data_generator,
total=self.input_data_generator_len): for i, ikey in enumerate(self.search_keys):
if np.sum(np.array(vals) != np.array(self.Lambda0s_grid)) < 3: setattr(search, ikey+'s', self.thetas[i])
FS = self.search.get_det_stat(*vals) search.run()
data.append(list(vals) + [FS]) axes[i, i] = search.plot_1D(ikey, ax=axes[i, i], savefig=False)
else: setattr(search, ikey+'s', [self.Lambda0[i]])
data.append(list(vals) + [0])
for j, jkey in enumerate(self.search_keys):
data = np.array(data, dtype=np.float) ax = axes[i, j]
if return_data:
return data if j > i:
else: ax.set_frame_on(False)
logging.info('Saving data to {}'.format(self.out_file)) ax.set_xticks([])
np.savetxt(self.out_file, data, delimiter=' ') ax.set_yticks([])
self.data = data continue
ax.get_shared_x_axes().join(axes[self.ndim-1, j], ax)
if i < self.ndim - 1:
ax.set_xticklabels([])
if j < i:
ax.get_shared_y_axes().join(axes[i, i-1], ax)
if j > 0:
ax.set_yticklabels([])
if j == i:
continue
max_n_ticks = 3
ax.xaxis.set_major_locator(
matplotlib.ticker.MaxNLocator(max_n_ticks, prune="upper"))
ax.yaxis.set_major_locator(
matplotlib.ticker.MaxNLocator(max_n_ticks, prune="upper"))
setattr(search, ikey+'s', self.thetas[i])
setattr(search, jkey+'s', self.thetas[j])
search.run()
ax = search.plot_2D(jkey, ikey, ax=ax, save=False)
setattr(search, ikey+'s', [self.Lambda0[i]])
setattr(search, jkey+'s', [self.Lambda0[j]])
fig.savefig(
'{}/{}_slice_projection.png'.format(self.outdir, self.label))
class GridUniformPriorSearch(): class GridUniformPriorSearch():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment