Commit 17486105 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Clean up of SliceGridSearch

- Just run separate 2D and 1D searches for each panel itself
- Some additional clean up of the basic GridSearch method
parent 6580d08a
......@@ -20,12 +20,17 @@ import lal
class GridSearch(BaseSearchClass):
""" Gridded search using ComputeFstat """
tex_labels = {'F0': '$f$', 'F1': '$\dot{f}$', 'F2': '$\ddot{f}$',
'Alpha': r'$\alpha$', 'Delta': r'$\delta$'}
tex_labels0 = {'F0': '$-f_0$', 'F1': '$-\dot{f}_0$', 'F2': '$-\ddot{f}_0$',
'Alpha': r'$-\alpha_0$', 'Deltas': r'$-\delta_0$'}
@helper_functions.initializer
def __init__(self, label, outdir, sftfilepattern, F0s=[0], F1s=[0], F2s=[0],
Alphas=[0], Deltas=[0], tref=None, minStartTime=None,
maxStartTime=None, nsegs=1, BSGL=False, minCoverFreq=None,
maxCoverFreq=None, detectors=None, SSBprec=None,
injectSources=None, input_arrays=False, assumeSqrtSX=None):
def __init__(self, label, outdir, sftfilepattern, F0s, F1s, F2s, Alphas,
Deltas, tref=None, minStartTime=None, maxStartTime=None,
nsegs=1, BSGL=False, minCoverFreq=None, maxCoverFreq=None,
detectors=None, SSBprec=None, injectSources=None,
input_arrays=False, assumeSqrtSX=None):
"""
Parameters
----------
......@@ -50,6 +55,9 @@ class GridSearch(BaseSearchClass):
os.mkdir(outdir)
self.set_out_file()
self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta']
self.search_keys = [x+'s' for x in self.keys[2:]]
for k in self.search_keys:
setattr(self, k, np.atleast_1d(getattr(self, k)))
def inititate_search_object(self):
logging.info('Setting up search object')
......@@ -88,8 +96,8 @@ class GridSearch(BaseSearchClass):
def get_input_data_array(self):
logging.info("Generating input data array")
coord_arrays = []
for tup in ([self.minStartTime], [self.maxStartTime], self.F0s, self.F1s, self.F2s,
self.Alphas, self.Deltas):
for tup in ([self.minStartTime], [self.maxStartTime], self.F0s,
self.F1s, self.F2s, self.Alphas, self.Deltas):
coord_arrays.append(self.get_array_from_tuple(tup))
self.input_data_generator_len = np.prod([len(k) for k in coord_arrays])
......@@ -112,15 +120,6 @@ class GridSearch(BaseSearchClass):
logging.info('No data caching available')
return False
#data = np.atleast_2d(np.genfromtxt(self.out_file, delimiter=' '))
#if np.all(data[:, 0:-1] == self.input_data):
# logging.info(
# 'Old data found with matching input, no search performed')
# return data
#else:
# logging.info(
# 'Old data found, input differs, continuing with grid search')
# return False
def run(self, return_data=False):
self.get_input_data_array()
......@@ -129,10 +128,12 @@ class GridSearch(BaseSearchClass):
self.data = old_data
return
self.inititate_search_object()
if hasattr(self, 'search') is False:
self.inititate_search_object()
data = []
for vals in tqdm(self.input_data_generator):
for vals in tqdm(self.input_data_generator,
total=self.input_data_generator_len):
FS = self.search.get_det_stat(*vals)
data.append(list(vals) + [FS])
......@@ -170,18 +171,27 @@ class GridSearch(BaseSearchClass):
m = self.convert_F1_to_mismatch(y, yhat, Tseg)
axY.set_ylim(m[0], m[-1])
def plot_1D(self, xkey):
fig, ax = plt.subplots()
def plot_1D(self, xkey, ax=None, x0=None, savefig=True):
if ax is None:
fig, ax = plt.subplots()
xidx = self.keys.index(xkey)
x = np.unique(self.data[:, xidx])
z = self.data[:, -1]
plt.plot(x, z)
fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label))
ax.plot(x, z)
if x0:
ax.set_xlabel(self.tex_labels[xkey]+self.tex_labels0[xkey])
else:
ax.set_xlabel(self.tex_labels[xkey])
if savefig:
fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label))
else:
return ax
def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None,
add_mismatch=None, xN=None, yN=None, flat_keys=[],
rel_flat_idxs=[], flatten_method=np.max, title=None,
predicted_twoF=None, cm=None, cbarkwargs={}, x0=None, y0=None):
predicted_twoF=None, cm=None, cbarkwargs={}, x0=None, y0=None,
colorbar=False):
""" Plots a 2D grid of 2F values
Parameters
......@@ -223,24 +233,23 @@ class GridSearch(BaseSearchClass):
cm = plt.cm.viridis
pax = ax.pcolormesh(X, Y, Z, cmap=cm, vmin=vmin, vmax=vmax)
cb = plt.colorbar(pax, ax=ax, **cbarkwargs)
cb.set_label('$2\mathcal{F}$')
if colorbar:
cb = plt.colorbar(pax, ax=ax, **cbarkwargs)
cb.set_label('$2\mathcal{F}$')
if add_mismatch:
self.add_mismatch_to_ax(ax, x, y, xkey, ykey, *add_mismatch)
ax.set_xlim(x[0], x[-1])
ax.set_ylim(y[0], y[-1])
labels = {'F0': '$f$', 'F1': '$\dot{f}$'}
labels0 = {'F0': '$-f_0$', 'F1': '$-\dot{f}_0$'}
if x0:
ax.set_xlabel(labels[xkey]+labels0[xkey])
ax.set_xlabel(self.tex_labels[xkey]+self.tex_labels0[xkey])
else:
ax.set_xlabel(labels[xkey])
ax.set_xlabel(self.tex_labels[xkey])
if y0:
ax.set_ylabel(labels[ykey]+labels0[ykey])
ax.set_ylabel(self.tex_labels[ykey]+self.tex_labels0[ykey])
else:
ax.set_ylabel(labels[ykey])
ax.set_ylabel(self.tex_labels[ykey])
if title:
ax.set_title(title)
......@@ -287,12 +296,11 @@ class GridSearch(BaseSearchClass):
class SliceGridSearch(GridSearch):
""" Slice gridded search using ComputeFstat """
@helper_functions.initializer
def __init__(self, label, outdir, sftfilepattern, F0s=[0], F1s=[0], F2s=[0],
Alphas=[0], Deltas=[0], tref=None, minStartTime=None,
maxStartTime=None, nsegs=1, BSGL=False, minCoverFreq=None,
maxCoverFreq=None, detectors=None, SSBprec=None,
injectSources=None, input_arrays=False, assumeSqrtSX=None,
Lambda0=None):
def __init__(self, label, outdir, sftfilepattern, F0s, F1s, F2s, Alphas,
Deltas, tref=None, minStartTime=None, maxStartTime=None,
nsegs=1, BSGL=False, minCoverFreq=None, maxCoverFreq=None,
detectors=None, SSBprec=None, injectSources=None,
input_arrays=False, assumeSqrtSX=None, Lambda0=None):
"""
Parameters
----------
......@@ -317,46 +325,76 @@ class SliceGridSearch(GridSearch):
os.mkdir(outdir)
self.set_out_file()
self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta']
self.ndim = 0
self.thetas = [F0s, F1s, Alphas, Deltas]
self.ndim = 4
self.search_keys = ['F0', 'F1', 'Alpha', 'Delta']
self.Lambda0 = np.array(Lambda0)
if len(self.Lambda0) != len(self.keys):
if len(self.Lambda0) != len(self.search_keys):
raise ValueError(
'Lambda0 must be of length {}'.format(len(self.keys)))
def run(self, return_data=False):
self.get_input_data_array()
self.Lambda0s_grid = []
for j, arr in enumerate(self.coord_arrays):
i = np.argmin(np.abs(self.Lambda0[j]-arr))
self.Lambda0s_grid.append(arr[i])
old_data = self.check_old_data_is_okay_to_use()
if old_data is not False:
self.data = old_data
return
self.inititate_search_object()
logging.info('Total number of grid points is {}'.format(
self.input_data_generator_len))
data = []
for vals in tqdm(self.input_data_generator,
total=self.input_data_generator_len):
if np.sum(np.array(vals) != np.array(self.Lambda0s_grid)) < 3:
FS = self.search.get_det_stat(*vals)
data.append(list(vals) + [FS])
else:
data.append(list(vals) + [0])
data = np.array(data, dtype=np.float)
if return_data:
return data
else:
logging.info('Saving data to {}'.format(self.out_file))
np.savetxt(self.out_file, data, delimiter=' ')
self.data = data
'Lambda0 must be of length {}'.format(len(self.search_keys)))
def run(self, factor=2):
lbdim = 0.5 * factor # size of left/bottom margin
trdim = 0.2 * factor # size of top/right margin
whspace = 0.05 # w/hspace size
plotdim = factor * self.ndim + factor * (self.ndim - 1.) * whspace
dim = lbdim + plotdim + trdim
fig, axes = plt.subplots(self.ndim, self.ndim, figsize=(dim, dim))
# Format the figure.
lb = lbdim / dim
tr = (lbdim + plotdim) / dim
fig.subplots_adjust(left=lb, bottom=lb, right=tr, top=tr,
wspace=whspace, hspace=whspace)
search = GridSearch(
self.label, self.outdir, self.sftfilepattern,
F0s=self.Lambda0[0], F1s=self.Lambda0[1], F2s=self.F2s[0],
Alphas=self.Lambda0[2], Deltas=self.Lambda0[3], tref=self.tref)
for i, ikey in enumerate(self.search_keys):
setattr(search, ikey+'s', self.thetas[i])
search.run()
axes[i, i] = search.plot_1D(ikey, ax=axes[i, i], savefig=False)
setattr(search, ikey+'s', [self.Lambda0[i]])
for j, jkey in enumerate(self.search_keys):
ax = axes[i, j]
if j > i:
ax.set_frame_on(False)
ax.set_xticks([])
ax.set_yticks([])
continue
ax.get_shared_x_axes().join(axes[self.ndim-1, j], ax)
if i < self.ndim - 1:
ax.set_xticklabels([])
if j < i:
ax.get_shared_y_axes().join(axes[i, i-1], ax)
if j > 0:
ax.set_yticklabels([])
if j == i:
continue
max_n_ticks = 3
ax.xaxis.set_major_locator(
matplotlib.ticker.MaxNLocator(max_n_ticks, prune="upper"))
ax.yaxis.set_major_locator(
matplotlib.ticker.MaxNLocator(max_n_ticks, prune="upper"))
setattr(search, ikey+'s', self.thetas[i])
setattr(search, jkey+'s', self.thetas[j])
search.run()
ax = search.plot_2D(jkey, ikey, ax=ax, save=False)
setattr(search, ikey+'s', [self.Lambda0[i]])
setattr(search, jkey+'s', [self.Lambda0[j]])
fig.savefig(
'{}/{}_slice_projection.png'.format(self.outdir, self.label))
class GridUniformPriorSearch():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment