""" Searches using grid-based methods """ import os import logging import itertools from collections import OrderedDict import numpy as np import matplotlib import matplotlib.pyplot as plt import helper_functions from core import BaseSearchClass, ComputeFstat, SemiCoherentGlitchSearch from core import tqdm, args, earth_ephem, sun_ephem class GridSearch(BaseSearchClass): """ Gridded search using ComputeFstat """ @helper_functions.initializer def __init__(self, label, outdir, sftfilepath, F0s=[0], F1s=[0], F2s=[0], Alphas=[0], Deltas=[0], tref=None, minStartTime=None, maxStartTime=None, BSGL=False, minCoverFreq=None, maxCoverFreq=None, earth_ephem=None, sun_ephem=None, detector=None): """ Parameters ---------- label, outdir: str A label and directory to read/write data from/to sftfilepath: str File patern to match SFTs F0s, F1s, F2s, delta_F0s, delta_F1s, tglitchs, Alphas, Deltas: tuple Length 3 tuple describing the grid for each parameter, e.g [F0min, F0max, dF0], for a fixed value simply give [F0]. tref, minStartTime, maxStartTime: int GPS seconds of the reference time, start time and end time For all other parameters, see `pyfstat.ComputeFStat` for details """ if earth_ephem is None: self.earth_ephem = self.earth_ephem_default if sun_ephem is None: self.sun_ephem = self.sun_ephem_default if os.path.isdir(outdir) is False: os.mkdir(outdir) self.out_file = '{}/{}_gridFS.txt'.format(self.outdir, self.label) self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta'] def inititate_search_object(self): logging.info('Setting up search object') self.search = ComputeFstat( tref=self.tref, sftfilepath=self.sftfilepath, minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq, earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem, detector=self.detector, transient=False, minStartTime=self.minStartTime, maxStartTime=self.maxStartTime, BSGL=self.BSGL) def get_array_from_tuple(self, x): if len(x) == 1: return np.array(x) else: return np.arange(x[0], x[1]*(1+1e-15), x[2]) def get_input_data_array(self): arrays = [] for tup in ([self.minStartTime], [self.maxStartTime], self.F0s, self.F1s, self.F2s, self.Alphas, self.Deltas): arrays.append(self.get_array_from_tuple(tup)) input_data = [] for vals in itertools.product(*arrays): input_data.append(vals) self.arrays = arrays self.input_data = np.array(input_data) def check_old_data_is_okay_to_use(self): if args.clean: return False if os.path.isfile(self.out_file) is False: logging.info('No old data found, continuing with grid search') return False data = np.atleast_2d(np.genfromtxt(self.out_file, delimiter=' ')) if np.all(data[:, 0:-1] == self.input_data): logging.info( 'Old data found with matching input, no search performed') return data else: logging.info( 'Old data found, input differs, continuing with grid search') return False def run(self, return_data=False): self.get_input_data_array() old_data = self.check_old_data_is_okay_to_use() if old_data is not False: self.data = old_data return self.inititate_search_object() logging.info('Total number of grid points is {}'.format( len(self.input_data))) data = [] for vals in tqdm(self.input_data): FS = self.search.run_computefstatistic_single_point(*vals) data.append(list(vals) + [FS]) data = np.array(data) if return_data: return data else: logging.info('Saving data to {}'.format(self.out_file)) np.savetxt(self.out_file, data, delimiter=' ') self.data = data def convert_F0_to_mismatch(self, F0, F0hat, Tseg): DeltaF0 = F0[1] - F0[0] m_spacing = (np.pi*Tseg*DeltaF0)**2 / 12. 
N = len(F0) return np.arange(-N*m_spacing/2., N*m_spacing/2., m_spacing) def convert_F1_to_mismatch(self, F1, F1hat, Tseg): DeltaF1 = F1[1] - F1[0] m_spacing = (np.pi*Tseg**2*DeltaF1)**2 / 720. N = len(F1) return np.arange(-N*m_spacing/2., N*m_spacing/2., m_spacing) def add_mismatch_to_ax(self, ax, x, y, xkey, ykey, xhat, yhat, Tseg): axX = ax.twiny() axX.zorder = -10 axY = ax.twinx() axY.zorder = -10 if xkey == 'F0': m = self.convert_F0_to_mismatch(x, xhat, Tseg) axX.set_xlim(m[0], m[-1]) if ykey == 'F1': m = self.convert_F1_to_mismatch(y, yhat, Tseg) axY.set_ylim(m[0], m[-1]) def plot_1D(self, xkey): fig, ax = plt.subplots() xidx = self.keys.index(xkey) x = np.unique(self.data[:, xidx]) z = self.data[:, -1] plt.plot(x, z) fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label)) def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None, add_mismatch=None, xN=None, yN=None, flat_keys=[], rel_flat_idxs=[], flatten_method=np.max, predicted_twoF=None, cm=None, cbarkwargs={}): """ Plots a 2D grid of 2F values Parameters ---------- add_mismatch: tuple (xhat, yhat, Tseg) If not None, add a secondary axis with the metric mismatch from the point xhat, yhat with duration Tseg flatten_method: np.max Function to use in flattening the flat_keys """ if ax is None: fig, ax = plt.subplots() xidx = self.keys.index(xkey) yidx = self.keys.index(ykey) flat_idxs = [self.keys.index(k) for k in flat_keys] x = np.unique(self.data[:, xidx]) y = np.unique(self.data[:, yidx]) flat_vals = [np.unique(self.data[:, j]) for j in flat_idxs] z = self.data[:, -1] Y, X = np.meshgrid(y, x) shape = [len(x), len(y)] + [len(v) for v in flat_vals] Z = z.reshape(shape) if len(rel_flat_idxs) > 0: Z = flatten_method(Z, axis=tuple(rel_flat_idxs)) if predicted_twoF: Z = (predicted_twoF - Z) / (predicted_twoF + 4) if cm is None: cm = plt.cm.viridis_r else: if cm is None: cm = plt.cm.viridis pax = ax.pcolormesh(X, Y, Z, cmap=cm, vmin=vmin, vmax=vmax) cb = plt.colorbar(pax, ax=ax, **cbarkwargs) cb.set_label('$2\mathcal{F}$') if add_mismatch: self.add_mismatch_to_ax(ax, x, y, xkey, ykey, *add_mismatch) ax.set_xlim(x[0], x[-1]) ax.set_ylim(y[0], y[-1]) labels = {'F0': '$f$', 'F1': '$\dot{f}$'} ax.set_xlabel(labels[xkey]) ax.set_ylabel(labels[ykey]) if xN: ax.xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(xN)) if yN: ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(yN)) if save: fig.tight_layout() fig.savefig('{}/{}_2D.png'.format(self.outdir, self.label)) else: return ax def get_max_twoF(self): twoF = self.data[:, -1] idx = np.argmax(twoF) v = self.data[idx, :] d = OrderedDict(minStartTime=v[0], maxStartTime=v[1], F0=v[2], F1=v[3], F2=v[4], Alpha=v[5], Delta=v[6], twoF=v[7]) return d def print_max_twoF(self): d = self.get_max_twoF() print('Max twoF values for {}:'.format(self.label)) for k, v in d.iteritems(): print(' {}={}'.format(k, v)) class GridGlitchSearch(GridSearch): """ Grid search using the SemiCoherentGlitchSearch """ @helper_functions.initializer def __init__(self, label, outdir, sftfilepath=None, F0s=[0], F1s=[0], F2s=[0], delta_F0s=[0], delta_F1s=[0], tglitchs=None, Alphas=[0], Deltas=[0], tref=None, minStartTime=None, maxStartTime=None, minCoverFreq=None, maxCoverFreq=None, write_after=1000, earth_ephem=None, sun_ephem=None): """ Parameters ---------- label, outdir: str A label and directory to read/write data from/to sftfilepath: str File patern to match SFTs F0s, F1s, F2s, delta_F0s, delta_F1s, tglitchs, Alphas, Deltas: tuple Length 3 tuple describing the grid for each parameter, e.g [F0min, 
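
# The sketch below illustrates how GridSearch is typically driven: build the
# object, run the grid, then inspect and plot the results. It is a minimal
# example only; the label, output directory, SFT file pattern, grid ranges
# and GPS times are hypothetical placeholders, not values taken from this
# module. The function is never called by the module itself.
def _example_grid_search():
    """ Illustrative (hypothetical) GridSearch usage sketch """
    search = GridSearch(
        label='grid_example', outdir='grid_example_data',
        sftfilepath='grid_example_data/*.sft',   # hypothetical SFT pattern
        F0s=[29.9, 30.1, 1e-3],                  # [min, max, step] in frequency
        F1s=[-1e-10, 0, 1e-11],                  # [min, max, step] in spin-down
        F2s=[0], Alphas=[1.0], Deltas=[0.5],     # fixed F2 and sky position
        tref=1000000000,                         # hypothetical GPS reference time
        minStartTime=1000000000, maxStartTime=1000086400)
    search.run()                # compute 2F on the grid, caching results to disk
    search.print_max_twoF()     # report the loudest grid point
    search.plot_2D('F0', 'F1')  # colormap of 2F over the F0-F1 plane
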
class GridGlitchSearch(GridSearch):
    """ Grid search using the SemiCoherentGlitchSearch; an illustrative usage
    sketch follows at the end of this module """
    @helper_functions.initializer
    def __init__(self, label, outdir, sftfilepath=None, F0s=[0],
                 F1s=[0], F2s=[0], delta_F0s=[0], delta_F1s=[0],
                 tglitchs=None, Alphas=[0], Deltas=[0], tref=None,
                 minStartTime=None, maxStartTime=None, BSGL=False,
                 minCoverFreq=None, maxCoverFreq=None, write_after=1000,
                 earth_ephem=None, sun_ephem=None):
        """
        Parameters
        ----------
        label, outdir: str
            A label and directory to read/write data from/to
        sftfilepath: str
            File pattern to match SFTs
        F0s, F1s, F2s, delta_F0s, delta_F1s, tglitchs, Alphas, Deltas: tuple
            Length 3 tuple describing the grid for each parameter, e.g.
            [F0min, F0max, dF0]; for a fixed value simply give [F0].
        tref, minStartTime, maxStartTime: int
            GPS seconds of the reference time, start time and end time

        For all other parameters, see pyfstat.ComputeFStat.
        """
        if tglitchs is None:
            self.tglitchs = [self.maxStartTime]
        if earth_ephem is None:
            self.earth_ephem = self.earth_ephem_default
        if sun_ephem is None:
            self.sun_ephem = self.sun_ephem_default

        self.search = SemiCoherentGlitchSearch(
            label=label, outdir=outdir, sftfilepath=self.sftfilepath,
            tref=tref, minStartTime=minStartTime, maxStartTime=maxStartTime,
            minCoverFreq=minCoverFreq, maxCoverFreq=maxCoverFreq,
            earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem,
            BSGL=self.BSGL)

        if os.path.isdir(outdir) is False:
            os.mkdir(outdir)
        self.out_file = '{}/{}_gridFS.txt'.format(self.outdir, self.label)
        self.keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta', 'delta_F0',
                     'delta_F1', 'tglitch']

    def get_input_data_array(self):
        arrays = []
        for tup in (self.F0s, self.F1s, self.F2s, self.Alphas, self.Deltas,
                    self.delta_F0s, self.delta_F1s, self.tglitchs):
            arrays.append(self.get_array_from_tuple(tup))

        input_data = []
        for vals in itertools.product(*arrays):
            input_data.append(vals)

        self.arrays = arrays
        self.input_data = np.array(input_data)
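
# A corresponding sketch for GridGlitchSearch, again purely illustrative: the
# file pattern, GPS times and glitch-parameter grids below are hypothetical
# assumptions. It only shows how the glitch grid is set up and enumerated;
# driving the search then follows the GridSearch pattern above. Never called
# by the module itself.
def _example_grid_glitch_search():
    """ Illustrative (hypothetical) GridGlitchSearch setup sketch """
    search = GridGlitchSearch(
        label='glitch_grid_example', outdir='glitch_grid_example_data',
        sftfilepath='glitch_grid_example_data/*.sft',  # hypothetical SFT pattern
        F0s=[29.9, 30.1, 1e-3], F1s=[-1e-10], F2s=[0],
        delta_F0s=[0, 1e-5, 1e-6],    # grid over the glitch frequency jump
        delta_F1s=[0],                # fixed spin-down jump
        tglitchs=[1000043200],        # hypothetical glitch-time value
        Alphas=[1.0], Deltas=[0.5],
        tref=1000000000,
        minStartTime=1000000000, maxStartTime=1000086400)
    # Build the grid over (F0, F1, F2, Alpha, Delta, delta_F0, delta_F1,
    # tglitch); the column order matches self.keys in GridGlitchSearch
    search.get_input_data_array()
    print('Number of glitch grid points: {}'.format(len(search.input_data)))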