Skip to content
Snippets Groups Projects
Commit a243f8e4 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Revert to array method and improve plotting

- Adds back in array method for caching data
parent 17486105
No related branches found
No related tags found
No related merge requests found
...@@ -23,7 +23,7 @@ class GridSearch(BaseSearchClass): ...@@ -23,7 +23,7 @@ class GridSearch(BaseSearchClass):
tex_labels = {'F0': '$f$', 'F1': '$\dot{f}$', 'F2': '$\ddot{f}$', tex_labels = {'F0': '$f$', 'F1': '$\dot{f}$', 'F2': '$\ddot{f}$',
'Alpha': r'$\alpha$', 'Delta': r'$\delta$'} 'Alpha': r'$\alpha$', 'Delta': r'$\delta$'}
tex_labels0 = {'F0': '$-f_0$', 'F1': '$-\dot{f}_0$', 'F2': '$-\ddot{f}_0$', tex_labels0 = {'F0': '$-f_0$', 'F1': '$-\dot{f}_0$', 'F2': '$-\ddot{f}_0$',
'Alpha': r'$-\alpha_0$', 'Deltas': r'$-\delta_0$'} 'Alpha': r'$-\alpha_0$', 'Delta': r'$-\delta_0$'}
@helper_functions.initializer @helper_functions.initializer
def __init__(self, label, outdir, sftfilepattern, F0s, F1s, F2s, Alphas, def __init__(self, label, outdir, sftfilepattern, F0s, F1s, F2s, Alphas,
...@@ -100,8 +100,10 @@ class GridSearch(BaseSearchClass): ...@@ -100,8 +100,10 @@ class GridSearch(BaseSearchClass):
self.F1s, self.F2s, self.Alphas, self.Deltas): self.F1s, self.F2s, self.Alphas, self.Deltas):
coord_arrays.append(self.get_array_from_tuple(tup)) coord_arrays.append(self.get_array_from_tuple(tup))
self.input_data_generator_len = np.prod([len(k) for k in coord_arrays]) input_data = []
self.input_data_generator = itertools.product(*coord_arrays) for vals in itertools.product(*coord_arrays):
input_data.append(vals)
self.input_data = np.array(input_data)
self.coord_arrays = coord_arrays self.coord_arrays = coord_arrays
def check_old_data_is_okay_to_use(self): def check_old_data_is_okay_to_use(self):
...@@ -118,7 +120,15 @@ class GridSearch(BaseSearchClass): ...@@ -118,7 +120,15 @@ class GridSearch(BaseSearchClass):
+ ' continuing with grid search') + ' continuing with grid search')
return False return False
logging.info('No data caching available') data = np.atleast_2d(np.genfromtxt(self.out_file, delimiter=' '))
if np.all(data[:, 0:-1] == self.input_data):
logging.info(
'Old data found with matching input, no search performed')
return data
else:
logging.info(
'Old data found, input differs, continuing with grid search')
return False
return False return False
def run(self, return_data=False): def run(self, return_data=False):
...@@ -132,8 +142,7 @@ class GridSearch(BaseSearchClass): ...@@ -132,8 +142,7 @@ class GridSearch(BaseSearchClass):
self.inititate_search_object() self.inititate_search_object()
data = [] data = []
for vals in tqdm(self.input_data_generator, for vals in tqdm(self.input_data):
total=self.input_data_generator_len):
FS = self.search.get_det_stat(*vals) FS = self.search.get_det_stat(*vals)
data.append(list(vals) + [FS]) data.append(list(vals) + [FS])
...@@ -176,6 +185,8 @@ class GridSearch(BaseSearchClass): ...@@ -176,6 +185,8 @@ class GridSearch(BaseSearchClass):
fig, ax = plt.subplots() fig, ax = plt.subplots()
xidx = self.keys.index(xkey) xidx = self.keys.index(xkey)
x = np.unique(self.data[:, xidx]) x = np.unique(self.data[:, xidx])
if x0:
x = x - x0
z = self.data[:, -1] z = self.data[:, -1]
ax.plot(x, z) ax.plot(x, z)
if x0: if x0:
...@@ -335,10 +346,9 @@ class SliceGridSearch(GridSearch): ...@@ -335,10 +346,9 @@ class SliceGridSearch(GridSearch):
raise ValueError( raise ValueError(
'Lambda0 must be of length {}'.format(len(self.search_keys))) 'Lambda0 must be of length {}'.format(len(self.search_keys)))
def run(self, factor=2): def run(self, factor=2, max_n_ticks=4, whspace=0.07, save=True):
lbdim = 0.5 * factor # size of left/bottom margin lbdim = 0.5 * factor # size of left/bottom margin
trdim = 0.2 * factor # size of top/right margin trdim = 0.4 * factor # size of top/right margin
whspace = 0.05 # w/hspace size
plotdim = factor * self.ndim + factor * (self.ndim - 1.) * whspace plotdim = factor * self.ndim + factor * (self.ndim - 1.) * whspace
dim = lbdim + plotdim + trdim dim = lbdim + plotdim + trdim
...@@ -353,13 +363,21 @@ class SliceGridSearch(GridSearch): ...@@ -353,13 +363,21 @@ class SliceGridSearch(GridSearch):
search = GridSearch( search = GridSearch(
self.label, self.outdir, self.sftfilepattern, self.label, self.outdir, self.sftfilepattern,
F0s=self.Lambda0[0], F1s=self.Lambda0[1], F2s=self.F2s[0], F0s=self.Lambda0[0], F1s=self.Lambda0[1], F2s=self.F2s[0],
Alphas=self.Lambda0[2], Deltas=self.Lambda0[3], tref=self.tref) Alphas=self.Lambda0[2], Deltas=self.Lambda0[3], tref=self.tref,
minStartTime=self.minStartTime, maxStartTime=self.maxStartTime)
for i, ikey in enumerate(self.search_keys): for i, ikey in enumerate(self.search_keys):
setattr(search, ikey+'s', self.thetas[i]) setattr(search, ikey+'s', self.thetas[i])
search.label = '{}_{}'.format(self.label, ikey)
search.set_out_file()
search.run() search.run()
axes[i, i] = search.plot_1D(ikey, ax=axes[i, i], savefig=False) axes[i, i] = search.plot_1D(ikey, ax=axes[i, i], savefig=False,
x0=self.Lambda0[i]
)
setattr(search, ikey+'s', [self.Lambda0[i]]) setattr(search, ikey+'s', [self.Lambda0[i]])
axes[i, i].yaxis.tick_right()
axes[i, i].yaxis.set_label_position("right")
axes[i, i].set_xlabel('')
for j, jkey in enumerate(self.search_keys): for j, jkey in enumerate(self.search_keys):
ax = axes[i, j] ax = axes[i, j]
...@@ -380,7 +398,6 @@ class SliceGridSearch(GridSearch): ...@@ -380,7 +398,6 @@ class SliceGridSearch(GridSearch):
if j == i: if j == i:
continue continue
max_n_ticks = 3
ax.xaxis.set_major_locator( ax.xaxis.set_major_locator(
matplotlib.ticker.MaxNLocator(max_n_ticks, prune="upper")) matplotlib.ticker.MaxNLocator(max_n_ticks, prune="upper"))
ax.yaxis.set_major_locator( ax.yaxis.set_major_locator(
...@@ -388,13 +405,32 @@ class SliceGridSearch(GridSearch): ...@@ -388,13 +405,32 @@ class SliceGridSearch(GridSearch):
setattr(search, ikey+'s', self.thetas[i]) setattr(search, ikey+'s', self.thetas[i])
setattr(search, jkey+'s', self.thetas[j]) setattr(search, jkey+'s', self.thetas[j])
search.label = '{}_{}'.format(self.label, ikey+jkey)
search.set_out_file()
search.run() search.run()
ax = search.plot_2D(jkey, ikey, ax=ax, save=False) ax = search.plot_2D(jkey, ikey, ax=ax, save=False,
y0=self.Lambda0[i], x0=self.Lambda0[j]
)
setattr(search, ikey+'s', [self.Lambda0[i]]) setattr(search, ikey+'s', [self.Lambda0[i]])
setattr(search, jkey+'s', [self.Lambda0[j]]) setattr(search, jkey+'s', [self.Lambda0[j]])
ax.grid(lw=0.2, ls='--', zorder=10)
ax.set_xlabel('')
ax.set_ylabel('')
for i, ikey in enumerate(self.search_keys):
axes[-1, i].set_xlabel(
self.tex_labels[ikey]+self.tex_labels0[ikey])
if i > 0:
axes[i, 0].set_ylabel(
self.tex_labels[ikey]+self.tex_labels0[ikey])
axes[i, i].set_ylabel("$2\mathcal{F}$")
if save:
fig.savefig( fig.savefig(
'{}/{}_slice_projection.png'.format(self.outdir, self.label)) '{}/{}_slice_projection.png'.format(self.outdir, self.label))
else:
return fig, axes
class GridUniformPriorSearch(): class GridUniformPriorSearch():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment