diff --git a/pyfstat/grid_based_searches.py b/pyfstat/grid_based_searches.py index a7eebcc942a20e72259b7fea41e7b406ecc15089..2a9d098249965e9e2293eb338b2ade1ef0dbb4c7 100644 --- a/pyfstat/grid_based_searches.py +++ b/pyfstat/grid_based_searches.py @@ -23,7 +23,7 @@ class GridSearch(BaseSearchClass): tex_labels = {'F0': '$f$', 'F1': '$\dot{f}$', 'F2': '$\ddot{f}$', 'Alpha': r'$\alpha$', 'Delta': r'$\delta$'} tex_labels0 = {'F0': '$-f_0$', 'F1': '$-\dot{f}_0$', 'F2': '$-\ddot{f}_0$', - 'Alpha': r'$-\alpha_0$', 'Deltas': r'$-\delta_0$'} + 'Alpha': r'$-\alpha_0$', 'Delta': r'$-\delta_0$'} @helper_functions.initializer def __init__(self, label, outdir, sftfilepattern, F0s, F1s, F2s, Alphas, @@ -100,8 +100,10 @@ class GridSearch(BaseSearchClass): self.F1s, self.F2s, self.Alphas, self.Deltas): coord_arrays.append(self.get_array_from_tuple(tup)) - self.input_data_generator_len = np.prod([len(k) for k in coord_arrays]) - self.input_data_generator = itertools.product(*coord_arrays) + input_data = [] + for vals in itertools.product(*coord_arrays): + input_data.append(vals) + self.input_data = np.array(input_data) self.coord_arrays = coord_arrays def check_old_data_is_okay_to_use(self): @@ -118,7 +120,15 @@ class GridSearch(BaseSearchClass): + ' continuing with grid search') return False - logging.info('No data caching available') + data = np.atleast_2d(np.genfromtxt(self.out_file, delimiter=' ')) + if np.all(data[:, 0:-1] == self.input_data): + logging.info( + 'Old data found with matching input, no search performed') + return data + else: + logging.info( + 'Old data found, input differs, continuing with grid search') + return False return False def run(self, return_data=False): @@ -132,8 +142,7 @@ class GridSearch(BaseSearchClass): self.inititate_search_object() data = [] - for vals in tqdm(self.input_data_generator, - total=self.input_data_generator_len): + for vals in tqdm(self.input_data): FS = self.search.get_det_stat(*vals) data.append(list(vals) + [FS]) @@ -176,6 +185,8 @@ class GridSearch(BaseSearchClass): fig, ax = plt.subplots() xidx = self.keys.index(xkey) x = np.unique(self.data[:, xidx]) + if x0: + x = x - x0 z = self.data[:, -1] ax.plot(x, z) if x0: @@ -335,10 +346,9 @@ class SliceGridSearch(GridSearch): raise ValueError( 'Lambda0 must be of length {}'.format(len(self.search_keys))) - def run(self, factor=2): + def run(self, factor=2, max_n_ticks=4, whspace=0.07, save=True): lbdim = 0.5 * factor # size of left/bottom margin - trdim = 0.2 * factor # size of top/right margin - whspace = 0.05 # w/hspace size + trdim = 0.4 * factor # size of top/right margin plotdim = factor * self.ndim + factor * (self.ndim - 1.) * whspace dim = lbdim + plotdim + trdim @@ -353,13 +363,21 @@ class SliceGridSearch(GridSearch): search = GridSearch( self.label, self.outdir, self.sftfilepattern, F0s=self.Lambda0[0], F1s=self.Lambda0[1], F2s=self.F2s[0], - Alphas=self.Lambda0[2], Deltas=self.Lambda0[3], tref=self.tref) + Alphas=self.Lambda0[2], Deltas=self.Lambda0[3], tref=self.tref, + minStartTime=self.minStartTime, maxStartTime=self.maxStartTime) for i, ikey in enumerate(self.search_keys): setattr(search, ikey+'s', self.thetas[i]) + search.label = '{}_{}'.format(self.label, ikey) + search.set_out_file() search.run() - axes[i, i] = search.plot_1D(ikey, ax=axes[i, i], savefig=False) + axes[i, i] = search.plot_1D(ikey, ax=axes[i, i], savefig=False, + x0=self.Lambda0[i] + ) setattr(search, ikey+'s', [self.Lambda0[i]]) + axes[i, i].yaxis.tick_right() + axes[i, i].yaxis.set_label_position("right") + axes[i, i].set_xlabel('') for j, jkey in enumerate(self.search_keys): ax = axes[i, j] @@ -380,7 +398,6 @@ class SliceGridSearch(GridSearch): if j == i: continue - max_n_ticks = 3 ax.xaxis.set_major_locator( matplotlib.ticker.MaxNLocator(max_n_ticks, prune="upper")) ax.yaxis.set_major_locator( @@ -388,13 +405,32 @@ class SliceGridSearch(GridSearch): setattr(search, ikey+'s', self.thetas[i]) setattr(search, jkey+'s', self.thetas[j]) + search.label = '{}_{}'.format(self.label, ikey+jkey) + search.set_out_file() search.run() - ax = search.plot_2D(jkey, ikey, ax=ax, save=False) + ax = search.plot_2D(jkey, ikey, ax=ax, save=False, + y0=self.Lambda0[i], x0=self.Lambda0[j] + ) setattr(search, ikey+'s', [self.Lambda0[i]]) setattr(search, jkey+'s', [self.Lambda0[j]]) - fig.savefig( - '{}/{}_slice_projection.png'.format(self.outdir, self.label)) + ax.grid(lw=0.2, ls='--', zorder=10) + ax.set_xlabel('') + ax.set_ylabel('') + + for i, ikey in enumerate(self.search_keys): + axes[-1, i].set_xlabel( + self.tex_labels[ikey]+self.tex_labels0[ikey]) + if i > 0: + axes[i, 0].set_ylabel( + self.tex_labels[ikey]+self.tex_labels0[ikey]) + axes[i, i].set_ylabel("$2\mathcal{F}$") + + if save: + fig.savefig( + '{}/{}_slice_projection.png'.format(self.outdir, self.label)) + else: + return fig, axes class GridUniformPriorSearch():