""" Searches using grid-based methods """ import os import logging import itertools from collections import OrderedDict import numpy as np import matplotlib import matplotlib.pyplot as plt import helper_functions from core import BaseSearchClass, ComputeFstat, SemiCoherentGlitchSearch, SemiCoherentSearch from core import tqdm, args, earth_ephem, sun_ephem, read_par class GridSearch(BaseSearchClass): """ Gridded search using ComputeFstat """ @helper_functions.initializer def __init__(self, label, outdir, sftfilepath, F0s=[0], F1s=[0], F2s=[0], Alphas=[0], Deltas=[0], tref=None, minStartTime=None, maxStartTime=None, nsegs=1, BSGL=False, minCoverFreq=None, maxCoverFreq=None, earth_ephem=None, sun_ephem=None, detectors=None, SSBprec=None, injectSources=None, input_arrays=False, assumeSqrtSX=None): """ Parameters ---------- label, outdir: str A label and directory to read/write data from/to sftfilepath: str File patern to match SFTs F0s, F1s, F2s, delta_F0s, delta_F1s, tglitchs, Alphas, Deltas: tuple Length 3 tuple describing the grid for each parameter, e.g [F0min, F0max, dF0], for a fixed value simply give [F0]. Unless input_arrays == True, then these are the values to search at. tref, minStartTime, maxStartTime: int GPS seconds of the reference time, start time and end time input_arrays: bool if true, use the F0s, F1s, etc as is For all other parameters, see `pyfstat.ComputeFStat` for details """ if earth_ephem is None: self.earth_ephem = self.earth_ephem_default if sun_ephem is None: self.sun_ephem = self.sun_ephem_default if os.path.isdir(outdir) is False: os.mkdir(outdir) self.out_file = '{}/{}_gridFS.txt'.format(self.outdir, self.label) self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta'] def inititate_search_object(self): logging.info('Setting up search object') if self.nsegs == 1: self.search = ComputeFstat( tref=self.tref, sftfilepath=self.sftfilepath, minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq, earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem, detectors=self.detectors, transient=False, minStartTime=self.minStartTime, maxStartTime=self.maxStartTime, BSGL=self.BSGL, SSBprec=self.SSBprec, injectSources=self.injectSources, assumeSqrtSX=self.assumeSqrtSX) self.search.get_det_stat = self.search.run_computefstatistic_single_point else: self.search = SemiCoherentSearch( label=self.label, outdir=self.outdir, tref=self.tref, nsegs=self.nsegs, sftfilepath=self.sftfilepath, BSGL=self.BSGL, minStartTime=self.minStartTime, maxStartTime=self.maxStartTime, minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq, detectors=self.detectors, earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem) def cut_out_tstart_tend(*vals): return self.search.run_semi_coherent_computefstatistic_single_point(*vals[2:]) self.search.get_det_stat = cut_out_tstart_tend def get_array_from_tuple(self, x): if len(x) == 1: return np.array(x) elif len(x) == 3 and self.input_arrays is False: return np.arange(x[0], x[1], x[2]) else: logging.info('Using tuple as is') return np.array(x) def get_input_data_array(self): arrays = [] for tup in ([self.minStartTime], [self.maxStartTime], self.F0s, self.F1s, self.F2s, self.Alphas, self.Deltas): arrays.append(self.get_array_from_tuple(tup)) input_data = [] for vals in itertools.product(*arrays): input_data.append(vals) self.arrays = arrays self.input_data = np.array(input_data) def check_old_data_is_okay_to_use(self): if args.clean: return False if os.path.isfile(self.out_file) is False: logging.info('No old data found, continuing with grid search') return False data = np.atleast_2d(np.genfromtxt(self.out_file, delimiter=' ')) if np.all(data[:, 0:-1] == self.input_data): logging.info( 'Old data found with matching input, no search performed') return data else: logging.info( 'Old data found, input differs, continuing with grid search') return False def run(self, return_data=False): self.get_input_data_array() old_data = self.check_old_data_is_okay_to_use() if old_data is not False: self.data = old_data return self.inititate_search_object() logging.info('Total number of grid points is {}'.format( len(self.input_data))) data = [] for vals in tqdm(self.input_data): FS = self.search.get_det_stat(*vals) data.append(list(vals) + [FS]) data = np.array(data, dtype=np.float) if return_data: return data else: logging.info('Saving data to {}'.format(self.out_file)) np.savetxt(self.out_file, data, delimiter=' ') self.data = data def convert_F0_to_mismatch(self, F0, F0hat, Tseg): DeltaF0 = F0[1] - F0[0] m_spacing = (np.pi*Tseg*DeltaF0)**2 / 12. N = len(F0) return np.arange(-N*m_spacing/2., N*m_spacing/2., m_spacing) def convert_F1_to_mismatch(self, F1, F1hat, Tseg): DeltaF1 = F1[1] - F1[0] m_spacing = (np.pi*Tseg**2*DeltaF1)**2 / 720. N = len(F1) return np.arange(-N*m_spacing/2., N*m_spacing/2., m_spacing) def add_mismatch_to_ax(self, ax, x, y, xkey, ykey, xhat, yhat, Tseg): axX = ax.twiny() axX.zorder = -10 axY = ax.twinx() axY.zorder = -10 if xkey == 'F0': m = self.convert_F0_to_mismatch(x, xhat, Tseg) axX.set_xlim(m[0], m[-1]) if ykey == 'F1': m = self.convert_F1_to_mismatch(y, yhat, Tseg) axY.set_ylim(m[0], m[-1]) def plot_1D(self, xkey): fig, ax = plt.subplots() xidx = self.keys.index(xkey) x = np.unique(self.data[:, xidx]) z = self.data[:, -1] plt.plot(x, z) fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label)) def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None, add_mismatch=None, xN=None, yN=None, flat_keys=[], rel_flat_idxs=[], flatten_method=np.max, title=None, predicted_twoF=None, cm=None, cbarkwargs={}): """ Plots a 2D grid of 2F values Parameters ---------- add_mismatch: tuple (xhat, yhat, Tseg) If not None, add a secondary axis with the metric mismatch from the point xhat, yhat with duration Tseg flatten_method: np.max Function to use in flattening the flat_keys """ if ax is None: fig, ax = plt.subplots() xidx = self.keys.index(xkey) yidx = self.keys.index(ykey) flat_idxs = [self.keys.index(k) for k in flat_keys] x = np.unique(self.data[:, xidx]) y = np.unique(self.data[:, yidx]) flat_vals = [np.unique(self.data[:, j]) for j in flat_idxs] z = self.data[:, -1] Y, X = np.meshgrid(y, x) shape = [len(x), len(y)] + [len(v) for v in flat_vals] Z = z.reshape(shape) if len(rel_flat_idxs) > 0: Z = flatten_method(Z, axis=tuple(rel_flat_idxs)) if predicted_twoF: Z = (predicted_twoF - Z) / (predicted_twoF + 4) if cm is None: cm = plt.cm.viridis_r else: if cm is None: cm = plt.cm.viridis pax = ax.pcolormesh(X, Y, Z, cmap=cm, vmin=vmin, vmax=vmax) cb = plt.colorbar(pax, ax=ax, **cbarkwargs) cb.set_label('$2\mathcal{F}$') if add_mismatch: self.add_mismatch_to_ax(ax, x, y, xkey, ykey, *add_mismatch) ax.set_xlim(x[0], x[-1]) ax.set_ylim(y[0], y[-1]) labels = {'F0': '$f$', 'F1': '$\dot{f}$'} ax.set_xlabel(labels[xkey]) ax.set_ylabel(labels[ykey]) if title: ax.set_title(title) if xN: ax.xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(xN)) if yN: ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(yN)) if save: fig.tight_layout() fig.savefig('{}/{}_2D.png'.format(self.outdir, self.label)) else: return ax def get_max_twoF(self): twoF = self.data[:, -1] idx = np.argmax(twoF) v = self.data[idx, :] d = OrderedDict(minStartTime=v[0], maxStartTime=v[1], F0=v[2], F1=v[3], F2=v[4], Alpha=v[5], Delta=v[6], twoF=v[7]) return d def print_max_twoF(self): d = self.get_max_twoF() print('Max twoF values for {}:'.format(self.label)) for k, v in d.iteritems(): print(' {}={}'.format(k, v)) class GridUniformPriorSearch(): def __init__(self, theta_prior, NF0, NF1, label, outdir, sftfilepath, tref, minStartTime, maxStartTime, minCoverFreq=None, maxCoverFreq=None, BSGL=False, detectors=None, nsegs=1): dF0 = (theta_prior['F0']['upper'] - theta_prior['F0']['lower'])/NF0 dF1 = (theta_prior['F1']['upper'] - theta_prior['F1']['lower'])/NF1 F0s = [theta_prior['F0']['lower'], theta_prior['F0']['upper'], dF0] F1s = [theta_prior['F1']['lower'], theta_prior['F1']['upper'], dF1] self.search = GridSearch( label, outdir, sftfilepath, F0s=F0s, F1s=F1s, tref=tref, Alphas=[theta_prior['Alpha']], Deltas=[theta_prior['Delta']], minStartTime=minStartTime, maxStartTime=maxStartTime, BSGL=BSGL, detectors=detectors, minCoverFreq=minCoverFreq, maxCoverFreq=maxCoverFreq, nsegs=nsegs) def run(self, **kwargs): self.search.run() return self.search.plot_2D('F0', 'F1', **kwargs) class GridGlitchSearch(GridSearch): """ Grid search using the SemiCoherentGlitchSearch """ @helper_functions.initializer def __init__(self, label, outdir, sftfilepath=None, F0s=[0], F1s=[0], F2s=[0], delta_F0s=[0], delta_F1s=[0], tglitchs=None, Alphas=[0], Deltas=[0], tref=None, minStartTime=None, maxStartTime=None, minCoverFreq=None, maxCoverFreq=None, write_after=1000, earth_ephem=None, sun_ephem=None): """ Parameters ---------- label, outdir: str A label and directory to read/write data from/to sftfilepath: str File patern to match SFTs F0s, F1s, F2s, delta_F0s, delta_F1s, tglitchs, Alphas, Deltas: tuple Length 3 tuple describing the grid for each parameter, e.g [F0min, F0max, dF0], for a fixed value simply give [F0]. tref, minStartTime, maxStartTime: int GPS seconds of the reference time, start time and end time For all other parameters, see pyfstat.ComputeFStat. """ if tglitchs is None: self.tglitchs = [self.maxStartTime] if earth_ephem is None: self.earth_ephem = self.earth_ephem_default if sun_ephem is None: self.sun_ephem = self.sun_ephem_default self.search = SemiCoherentGlitchSearch( label=label, outdir=outdir, sftfilepath=self.sftfilepath, tref=tref, minStartTime=minStartTime, maxStartTime=maxStartTime, minCoverFreq=minCoverFreq, maxCoverFreq=maxCoverFreq, earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem, BSGL=self.BSGL) if os.path.isdir(outdir) is False: os.mkdir(outdir) self.out_file = '{}/{}_gridFS.txt'.format(self.outdir, self.label) self.keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta', 'delta_F0', 'delta_F1', 'tglitch'] def get_input_data_array(self): arrays = [] for tup in (self.F0s, self.F1s, self.F2s, self.Alphas, self.Deltas, self.delta_F0s, self.delta_F1s, self.tglitchs): arrays.append(self.get_array_from_tuple(tup)) input_data = [] for vals in itertools.product(*arrays): input_data.append(vals) self.arrays = arrays self.input_data = np.array(input_data) class FrequencySlidingWindow(GridSearch): """ A sliding-window search over the Frequency """ @helper_functions.initializer def __init__(self, label, outdir, sftfilepath, F0s, F1, F2, Alpha, Delta, tref, minStartTime=None, maxStartTime=None, window_size=10*86400, window_delta=86400, BSGL=False, minCoverFreq=None, maxCoverFreq=None, earth_ephem=None, sun_ephem=None, detectors=None, SSBprec=None, injectSources=None): """ Parameters ---------- label, outdir: str A label and directory to read/write data from/to sftfilepath: str File patern to match SFTs F0s: array Frequency range F1, F2, Alpha, Delta: float Fixed values to compute twoF(F) over tref, minStartTime, maxStartTime: int GPS seconds of the reference time, start time and end time For all other parameters, see `pyfstat.ComputeFStat` for details """ if earth_ephem is None: self.earth_ephem = self.earth_ephem_default if sun_ephem is None: self.sun_ephem = self.sun_ephem_default if os.path.isdir(outdir) is False: os.mkdir(outdir) self.out_file = '{}/{}_gridFS.txt'.format(self.outdir, self.label) self.nsegs = 1 self.F1s = [F1] self.F2s = [F2] self.Alphas = [Alpha] self.Deltas = [Delta] def inititate_search_object(self): logging.info('Setting up search object') self.search = ComputeFstat( tref=self.tref, sftfilepath=self.sftfilepath, minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq, earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem, detectors=self.detectors, transient=True, minStartTime=self.minStartTime, maxStartTime=self.maxStartTime, BSGL=self.BSGL, SSBprec=self.SSBprec, injectSources=self.injectSources) self.search.get_det_stat = ( self.search.run_computefstatistic_single_point) def get_input_data_array(self): arrays = [] tstarts = [self.minStartTime] while tstarts[-1] + self.window_size < self.maxStartTime: tstarts.append(tstarts[-1]+self.window_delta) arrays = [tstarts] for tup in (self.F0s, self.F1s, self.F2s, self.Alphas, self.Deltas): arrays.append(self.get_array_from_tuple(tup)) input_data = [] for vals in itertools.product(*arrays): input_data.append(vals) input_data = np.array(input_data) input_data = np.insert( input_data, 1, input_data[:, 0] + self.window_size, axis=1) self.arrays = arrays self.input_data = np.array(input_data) def plot_sliding_window(self, F0=None, ax=None, savefig=True, colorbar=True, timestamps=False): data = self.data if ax is None: ax = plt.subplot() tstarts = np.unique(data[:, 0]) tends = np.unique(data[:, 1]) frequencies = np.unique(data[:, 2]) twoF = data[:, -1] tmids = (tstarts + tends) / 2.0 dts = (tmids - self.minStartTime) / 86400. if F0: frequencies = frequencies - F0 ax.set_ylabel('Frequency - $f_0$ [Hz] \n $f_0={:0.2f}$'.format(F0)) else: ax.set_ylabel('Frequency [Hz]') twoF = twoF.reshape((len(tmids), len(frequencies))) Y, X = np.meshgrid(frequencies, dts) pax = ax.pcolormesh(X, Y, twoF) if colorbar: cb = plt.colorbar(pax, ax=ax) cb.set_label('$2\mathcal{F}$') ax.set_xlabel( r'Mid-point (days after $t_\mathrm{{start}}$={})'.format( self.minStartTime)) ax.set_title( 'Sliding window length = {} days in increments of {} days' .format(self.window_size/86400, self.window_delta/86400), ) if timestamps: axT = ax.twiny() axT.set_xlim(tmids[0]*1e-9, tmids[-1]*1e-9) axT.set_xlabel('Mid-point timestamp [GPS $10^{9}$ s]') ax.set_title(ax.get_title(), y=1.18) if savefig: plt.tight_layout() plt.savefig( '{}/{}_sliding_window.png'.format(self.outdir, self.label)) else: return ax class DMoff_NO_SPIN(GridSearch): """ DMoff test using SSBPREC_NO_SPIN """ @helper_functions.initializer def __init__(self, par, label, outdir, sftfilepath, minStartTime=None, maxStartTime=None, minCoverFreq=None, maxCoverFreq=None, earth_ephem=None, sun_ephem=None, detectors=None, injectSources=None, assumeSqrtSX=None): """ Parameters ---------- par: dict, str Either a par dictionary (containing 'F0', 'F1', 'Alpha', 'Delta' and 'tref') or a path to a .par file to read in the F0, F1 etc label, outdir: str A label and directory to read/write data from/to sftfilepath: str File patern to match SFTs minStartTime, maxStartTime: int GPS seconds of the start time and end time For all other parameters, see `pyfstat.ComputeFStat` for details """ if earth_ephem is None: self.earth_ephem = self.earth_ephem_default if sun_ephem is None: self.sun_ephem = self.sun_ephem_default if os.path.isdir(outdir) is False: os.mkdir(outdir) if type(par) == dict: self.par = par elif type(par) == str and os.path.isfile(par): self.par = read_par(filename=par) else: raise ValueError('The .par file does not exist') self.nsegs = 1 self.BSGL = False self.tref = self.par['tref'] self.F1s = [self.par.get('F1', 0)] self.F2s = [self.par.get('F2', 0)] self.Alphas = [self.par['Alpha']] self.Deltas = [self.par['Delta']] self.Re = 6.371e6 self.c = 2.998e8 self.SIDEREAL_DAY = 23*60*60 + 56*60 + 4.0916 self.TERRESTRIAL_DAY = 86400. a0 = self.Re/self.c # *np.cos(self.par['Delta']) self.m0 = np.max([4, int(np.ceil(2*np.pi*self.par['F0']*a0))]) logging.info('m0 = {}'.format(self.m0)) def get_results(self): """ Compute the three summed detection statistics Returns ------- m0, twoF_SUM, twoFstar_SUM_SIDEREAL, twoFstar_SUM_TERRESTRIAL """ self.SSBprec = 2 self.out_file = '{}/{}_gridFS_SSBPREC2.txt'.format( self.outdir, self.label) self.F0s = [self.par['F0']+j/self.SIDEREAL_DAY for j in range(-self.m0, self.m0+1)] self.run() twoF_SUM = np.sum(self.data[:, -1]) self.SSBprec = 4 self.out_file = '{}/{}_gridFS_SSBPREC4_SIDEREAL.txt'.format( self.outdir, self.label) self.F0s = [self.par['F0']+j/self.SIDEREAL_DAY for j in range(-self.m0, self.m0+1)] self.run() twoFstar_SUM = np.sum(self.data[:, -1]) self.out_file = '{}/{}_gridFS_SSBPREC4_TERRESTIAL.txt'.format( self.outdir, self.label) self.F0s = [self.par['F0']+j/self.TERRESTRIAL_DAY for j in range(-self.m0, self.m0+1)] self.run() twoFstar_SUM_terrestrial = np.sum(self.data[:, -1]) return self.m0, twoF_SUM, twoFstar_SUM, twoFstar_SUM_terrestrial