Commit a243f8e4 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Revert to array method and improve plotting

- Adds back in array method for caching data
parent 17486105
......@@ -23,7 +23,7 @@ class GridSearch(BaseSearchClass):
tex_labels = {'F0': '$f$', 'F1': '$\dot{f}$', 'F2': '$\ddot{f}$',
'Alpha': r'$\alpha$', 'Delta': r'$\delta$'}
tex_labels0 = {'F0': '$-f_0$', 'F1': '$-\dot{f}_0$', 'F2': '$-\ddot{f}_0$',
'Alpha': r'$-\alpha_0$', 'Deltas': r'$-\delta_0$'}
'Alpha': r'$-\alpha_0$', 'Delta': r'$-\delta_0$'}
@helper_functions.initializer
def __init__(self, label, outdir, sftfilepattern, F0s, F1s, F2s, Alphas,
......@@ -100,8 +100,10 @@ class GridSearch(BaseSearchClass):
self.F1s, self.F2s, self.Alphas, self.Deltas):
coord_arrays.append(self.get_array_from_tuple(tup))
self.input_data_generator_len = np.prod([len(k) for k in coord_arrays])
self.input_data_generator = itertools.product(*coord_arrays)
input_data = []
for vals in itertools.product(*coord_arrays):
input_data.append(vals)
self.input_data = np.array(input_data)
self.coord_arrays = coord_arrays
def check_old_data_is_okay_to_use(self):
......@@ -118,7 +120,15 @@ class GridSearch(BaseSearchClass):
+ ' continuing with grid search')
return False
logging.info('No data caching available')
data = np.atleast_2d(np.genfromtxt(self.out_file, delimiter=' '))
if np.all(data[:, 0:-1] == self.input_data):
logging.info(
'Old data found with matching input, no search performed')
return data
else:
logging.info(
'Old data found, input differs, continuing with grid search')
return False
return False
def run(self, return_data=False):
......@@ -132,8 +142,7 @@ class GridSearch(BaseSearchClass):
self.inititate_search_object()
data = []
for vals in tqdm(self.input_data_generator,
total=self.input_data_generator_len):
for vals in tqdm(self.input_data):
FS = self.search.get_det_stat(*vals)
data.append(list(vals) + [FS])
......@@ -176,6 +185,8 @@ class GridSearch(BaseSearchClass):
fig, ax = plt.subplots()
xidx = self.keys.index(xkey)
x = np.unique(self.data[:, xidx])
if x0:
x = x - x0
z = self.data[:, -1]
ax.plot(x, z)
if x0:
......@@ -335,10 +346,9 @@ class SliceGridSearch(GridSearch):
raise ValueError(
'Lambda0 must be of length {}'.format(len(self.search_keys)))
def run(self, factor=2):
def run(self, factor=2, max_n_ticks=4, whspace=0.07, save=True):
lbdim = 0.5 * factor # size of left/bottom margin
trdim = 0.2 * factor # size of top/right margin
whspace = 0.05 # w/hspace size
trdim = 0.4 * factor # size of top/right margin
plotdim = factor * self.ndim + factor * (self.ndim - 1.) * whspace
dim = lbdim + plotdim + trdim
......@@ -353,13 +363,21 @@ class SliceGridSearch(GridSearch):
search = GridSearch(
self.label, self.outdir, self.sftfilepattern,
F0s=self.Lambda0[0], F1s=self.Lambda0[1], F2s=self.F2s[0],
Alphas=self.Lambda0[2], Deltas=self.Lambda0[3], tref=self.tref)
Alphas=self.Lambda0[2], Deltas=self.Lambda0[3], tref=self.tref,
minStartTime=self.minStartTime, maxStartTime=self.maxStartTime)
for i, ikey in enumerate(self.search_keys):
setattr(search, ikey+'s', self.thetas[i])
search.label = '{}_{}'.format(self.label, ikey)
search.set_out_file()
search.run()
axes[i, i] = search.plot_1D(ikey, ax=axes[i, i], savefig=False)
axes[i, i] = search.plot_1D(ikey, ax=axes[i, i], savefig=False,
x0=self.Lambda0[i]
)
setattr(search, ikey+'s', [self.Lambda0[i]])
axes[i, i].yaxis.tick_right()
axes[i, i].yaxis.set_label_position("right")
axes[i, i].set_xlabel('')
for j, jkey in enumerate(self.search_keys):
ax = axes[i, j]
......@@ -380,7 +398,6 @@ class SliceGridSearch(GridSearch):
if j == i:
continue
max_n_ticks = 3
ax.xaxis.set_major_locator(
matplotlib.ticker.MaxNLocator(max_n_ticks, prune="upper"))
ax.yaxis.set_major_locator(
......@@ -388,13 +405,32 @@ class SliceGridSearch(GridSearch):
setattr(search, ikey+'s', self.thetas[i])
setattr(search, jkey+'s', self.thetas[j])
search.label = '{}_{}'.format(self.label, ikey+jkey)
search.set_out_file()
search.run()
ax = search.plot_2D(jkey, ikey, ax=ax, save=False)
ax = search.plot_2D(jkey, ikey, ax=ax, save=False,
y0=self.Lambda0[i], x0=self.Lambda0[j]
)
setattr(search, ikey+'s', [self.Lambda0[i]])
setattr(search, jkey+'s', [self.Lambda0[j]])
fig.savefig(
'{}/{}_slice_projection.png'.format(self.outdir, self.label))
ax.grid(lw=0.2, ls='--', zorder=10)
ax.set_xlabel('')
ax.set_ylabel('')
for i, ikey in enumerate(self.search_keys):
axes[-1, i].set_xlabel(
self.tex_labels[ikey]+self.tex_labels0[ikey])
if i > 0:
axes[i, 0].set_ylabel(
self.tex_labels[ikey]+self.tex_labels0[ikey])
axes[i, i].set_ylabel("$2\mathcal{F}$")
if save:
fig.savefig(
'{}/{}_slice_projection.png'.format(self.outdir, self.label))
else:
return fig, axes
class GridUniformPriorSearch():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment