diff --git a/pyfstat/grid_based_searches.py b/pyfstat/grid_based_searches.py index d231961df1465f14e231c479f78b305695a1a50b..a7eebcc942a20e72259b7fea41e7b406ecc15089 100644 --- a/pyfstat/grid_based_searches.py +++ b/pyfstat/grid_based_searches.py @@ -20,12 +20,17 @@ import lal class GridSearch(BaseSearchClass): """ Gridded search using ComputeFstat """ + tex_labels = {'F0': '$f$', 'F1': '$\dot{f}$', 'F2': '$\ddot{f}$', + 'Alpha': r'$\alpha$', 'Delta': r'$\delta$'} + tex_labels0 = {'F0': '$-f_0$', 'F1': '$-\dot{f}_0$', 'F2': '$-\ddot{f}_0$', + 'Alpha': r'$-\alpha_0$', 'Deltas': r'$-\delta_0$'} + @helper_functions.initializer - def __init__(self, label, outdir, sftfilepattern, F0s=[0], F1s=[0], F2s=[0], - Alphas=[0], Deltas=[0], tref=None, minStartTime=None, - maxStartTime=None, nsegs=1, BSGL=False, minCoverFreq=None, - maxCoverFreq=None, detectors=None, SSBprec=None, - injectSources=None, input_arrays=False, assumeSqrtSX=None): + def __init__(self, label, outdir, sftfilepattern, F0s, F1s, F2s, Alphas, + Deltas, tref=None, minStartTime=None, maxStartTime=None, + nsegs=1, BSGL=False, minCoverFreq=None, maxCoverFreq=None, + detectors=None, SSBprec=None, injectSources=None, + input_arrays=False, assumeSqrtSX=None): """ Parameters ---------- @@ -50,6 +55,9 @@ class GridSearch(BaseSearchClass): os.mkdir(outdir) self.set_out_file() self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta'] + self.search_keys = [x+'s' for x in self.keys[2:]] + for k in self.search_keys: + setattr(self, k, np.atleast_1d(getattr(self, k))) def inititate_search_object(self): logging.info('Setting up search object') @@ -88,8 +96,8 @@ class GridSearch(BaseSearchClass): def get_input_data_array(self): logging.info("Generating input data array") coord_arrays = [] - for tup in ([self.minStartTime], [self.maxStartTime], self.F0s, self.F1s, self.F2s, - self.Alphas, self.Deltas): + for tup in ([self.minStartTime], [self.maxStartTime], self.F0s, + self.F1s, self.F2s, self.Alphas, self.Deltas): coord_arrays.append(self.get_array_from_tuple(tup)) self.input_data_generator_len = np.prod([len(k) for k in coord_arrays]) @@ -112,15 +120,6 @@ class GridSearch(BaseSearchClass): logging.info('No data caching available') return False - #data = np.atleast_2d(np.genfromtxt(self.out_file, delimiter=' ')) - #if np.all(data[:, 0:-1] == self.input_data): - # logging.info( - # 'Old data found with matching input, no search performed') - # return data - #else: - # logging.info( - # 'Old data found, input differs, continuing with grid search') - # return False def run(self, return_data=False): self.get_input_data_array() @@ -129,10 +128,12 @@ class GridSearch(BaseSearchClass): self.data = old_data return - self.inititate_search_object() + if hasattr(self, 'search') is False: + self.inititate_search_object() data = [] - for vals in tqdm(self.input_data_generator): + for vals in tqdm(self.input_data_generator, + total=self.input_data_generator_len): FS = self.search.get_det_stat(*vals) data.append(list(vals) + [FS]) @@ -170,18 +171,27 @@ class GridSearch(BaseSearchClass): m = self.convert_F1_to_mismatch(y, yhat, Tseg) axY.set_ylim(m[0], m[-1]) - def plot_1D(self, xkey): - fig, ax = plt.subplots() + def plot_1D(self, xkey, ax=None, x0=None, savefig=True): + if ax is None: + fig, ax = plt.subplots() xidx = self.keys.index(xkey) x = np.unique(self.data[:, xidx]) z = self.data[:, -1] - plt.plot(x, z) - fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label)) + ax.plot(x, z) + if x0: + ax.set_xlabel(self.tex_labels[xkey]+self.tex_labels0[xkey]) + else: + ax.set_xlabel(self.tex_labels[xkey]) + if savefig: + fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label)) + else: + return ax def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None, add_mismatch=None, xN=None, yN=None, flat_keys=[], rel_flat_idxs=[], flatten_method=np.max, title=None, - predicted_twoF=None, cm=None, cbarkwargs={}, x0=None, y0=None): + predicted_twoF=None, cm=None, cbarkwargs={}, x0=None, y0=None, + colorbar=False): """ Plots a 2D grid of 2F values Parameters @@ -223,24 +233,23 @@ class GridSearch(BaseSearchClass): cm = plt.cm.viridis pax = ax.pcolormesh(X, Y, Z, cmap=cm, vmin=vmin, vmax=vmax) - cb = plt.colorbar(pax, ax=ax, **cbarkwargs) - cb.set_label('$2\mathcal{F}$') + if colorbar: + cb = plt.colorbar(pax, ax=ax, **cbarkwargs) + cb.set_label('$2\mathcal{F}$') if add_mismatch: self.add_mismatch_to_ax(ax, x, y, xkey, ykey, *add_mismatch) ax.set_xlim(x[0], x[-1]) ax.set_ylim(y[0], y[-1]) - labels = {'F0': '$f$', 'F1': '$\dot{f}$'} - labels0 = {'F0': '$-f_0$', 'F1': '$-\dot{f}_0$'} if x0: - ax.set_xlabel(labels[xkey]+labels0[xkey]) + ax.set_xlabel(self.tex_labels[xkey]+self.tex_labels0[xkey]) else: - ax.set_xlabel(labels[xkey]) + ax.set_xlabel(self.tex_labels[xkey]) if y0: - ax.set_ylabel(labels[ykey]+labels0[ykey]) + ax.set_ylabel(self.tex_labels[ykey]+self.tex_labels0[ykey]) else: - ax.set_ylabel(labels[ykey]) + ax.set_ylabel(self.tex_labels[ykey]) if title: ax.set_title(title) @@ -287,12 +296,11 @@ class GridSearch(BaseSearchClass): class SliceGridSearch(GridSearch): """ Slice gridded search using ComputeFstat """ @helper_functions.initializer - def __init__(self, label, outdir, sftfilepattern, F0s=[0], F1s=[0], F2s=[0], - Alphas=[0], Deltas=[0], tref=None, minStartTime=None, - maxStartTime=None, nsegs=1, BSGL=False, minCoverFreq=None, - maxCoverFreq=None, detectors=None, SSBprec=None, - injectSources=None, input_arrays=False, assumeSqrtSX=None, - Lambda0=None): + def __init__(self, label, outdir, sftfilepattern, F0s, F1s, F2s, Alphas, + Deltas, tref=None, minStartTime=None, maxStartTime=None, + nsegs=1, BSGL=False, minCoverFreq=None, maxCoverFreq=None, + detectors=None, SSBprec=None, injectSources=None, + input_arrays=False, assumeSqrtSX=None, Lambda0=None): """ Parameters ---------- @@ -317,46 +325,76 @@ class SliceGridSearch(GridSearch): os.mkdir(outdir) self.set_out_file() self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta'] + self.ndim = 0 + self.thetas = [F0s, F1s, Alphas, Deltas] + self.ndim = 4 + self.search_keys = ['F0', 'F1', 'Alpha', 'Delta'] self.Lambda0 = np.array(Lambda0) - if len(self.Lambda0) != len(self.keys): + if len(self.Lambda0) != len(self.search_keys): raise ValueError( - 'Lambda0 must be of length {}'.format(len(self.keys))) - - def run(self, return_data=False): - self.get_input_data_array() - - self.Lambda0s_grid = [] - for j, arr in enumerate(self.coord_arrays): - i = np.argmin(np.abs(self.Lambda0[j]-arr)) - self.Lambda0s_grid.append(arr[i]) - - old_data = self.check_old_data_is_okay_to_use() - if old_data is not False: - self.data = old_data - return - - self.inititate_search_object() - - logging.info('Total number of grid points is {}'.format( - self.input_data_generator_len)) - - data = [] - for vals in tqdm(self.input_data_generator, - total=self.input_data_generator_len): - if np.sum(np.array(vals) != np.array(self.Lambda0s_grid)) < 3: - FS = self.search.get_det_stat(*vals) - data.append(list(vals) + [FS]) - else: - data.append(list(vals) + [0]) - - data = np.array(data, dtype=np.float) - if return_data: - return data - else: - logging.info('Saving data to {}'.format(self.out_file)) - np.savetxt(self.out_file, data, delimiter=' ') - self.data = data + 'Lambda0 must be of length {}'.format(len(self.search_keys))) + + def run(self, factor=2): + lbdim = 0.5 * factor # size of left/bottom margin + trdim = 0.2 * factor # size of top/right margin + whspace = 0.05 # w/hspace size + plotdim = factor * self.ndim + factor * (self.ndim - 1.) * whspace + dim = lbdim + plotdim + trdim + + fig, axes = plt.subplots(self.ndim, self.ndim, figsize=(dim, dim)) + + # Format the figure. + lb = lbdim / dim + tr = (lbdim + plotdim) / dim + fig.subplots_adjust(left=lb, bottom=lb, right=tr, top=tr, + wspace=whspace, hspace=whspace) + + search = GridSearch( + self.label, self.outdir, self.sftfilepattern, + F0s=self.Lambda0[0], F1s=self.Lambda0[1], F2s=self.F2s[0], + Alphas=self.Lambda0[2], Deltas=self.Lambda0[3], tref=self.tref) + + for i, ikey in enumerate(self.search_keys): + setattr(search, ikey+'s', self.thetas[i]) + search.run() + axes[i, i] = search.plot_1D(ikey, ax=axes[i, i], savefig=False) + setattr(search, ikey+'s', [self.Lambda0[i]]) + + for j, jkey in enumerate(self.search_keys): + ax = axes[i, j] + + if j > i: + ax.set_frame_on(False) + ax.set_xticks([]) + ax.set_yticks([]) + continue + + ax.get_shared_x_axes().join(axes[self.ndim-1, j], ax) + if i < self.ndim - 1: + ax.set_xticklabels([]) + if j < i: + ax.get_shared_y_axes().join(axes[i, i-1], ax) + if j > 0: + ax.set_yticklabels([]) + if j == i: + continue + + max_n_ticks = 3 + ax.xaxis.set_major_locator( + matplotlib.ticker.MaxNLocator(max_n_ticks, prune="upper")) + ax.yaxis.set_major_locator( + matplotlib.ticker.MaxNLocator(max_n_ticks, prune="upper")) + + setattr(search, ikey+'s', self.thetas[i]) + setattr(search, jkey+'s', self.thetas[j]) + search.run() + ax = search.plot_2D(jkey, ikey, ax=ax, save=False) + setattr(search, ikey+'s', [self.Lambda0[i]]) + setattr(search, jkey+'s', [self.Lambda0[j]]) + + fig.savefig( + '{}/{}_slice_projection.png'.format(self.outdir, self.label)) class GridUniformPriorSearch():