diff --git a/README.md b/README.md index fb83dd740197c701f1bd2724cb021211c0f1c8d7..abd92f7aa67fe13308ddf16d86441e4364e0fef7 100644 --- a/README.md +++ b/README.md @@ -40,11 +40,7 @@ are provided in the links. ### Dependencies -`pyfstat` makes use of a variety python modules listed as the -`imports` in the top of `pyfstat.py`. The first set are core modules (such as -`os`, `sys`) while the second set are external and need to be installed for -`pyfstat` to work properly. Please install the following widely available -modules: +`pyfstat` makes uses the following external python modules: * [numpy](http://www.numpy.org/) * [matplotlib](http://matplotlib.org/) diff --git a/pyfstat/__init__.py b/pyfstat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9e6c40f0a2abebfbc522115f6fbf37e25508ed30 --- /dev/null +++ b/pyfstat/__init__.py @@ -0,0 +1,6 @@ +from __future__ import division + +from .core import BaseSearchClass, ComputeFstat, Writer +from .mcmc_based_searches import * +from .grid_based_searches import * + diff --git a/pyfstat/core.py b/pyfstat/core.py new file mode 100755 index 0000000000000000000000000000000000000000..7a23a08a158cc588fc8f969809f746922f15a80c --- /dev/null +++ b/pyfstat/core.py @@ -0,0 +1,912 @@ +""" The core tools used in pyfstat """ +import os +import logging +import copy +import glob +import subprocess + +import numpy as np +import matplotlib.pyplot as plt +import scipy.special +import scipy.optimize +import lal +import lalpulsar + +import helper_functions +tqdm = helper_functions.set_up_optional_tqdm() +helper_functions.set_up_matplotlib_defaults() +args = helper_functions.set_up_command_line_arguments() +earth_ephem, sun_ephem = helper_functions.set_up_ephemeris_configuration() + + +def read_par(label, outdir): + """ Read in a .par file, returns a dictionary of the values """ + filename = '{}/{}.par'.format(outdir, label) + d = {} + with open(filename, 'r') as f: + for line in f: + if len(line.split('=')) > 1: + key, val = line.rstrip('\n').split(' = ') + key = key.strip() + d[key] = np.float64(eval(val.rstrip('; '))) + return d + + +class BaseSearchClass(object): + """ The base search class, provides general functions """ + + earth_ephem_default = earth_ephem + sun_ephem_default = sun_ephem + + def add_log_file(self): + """ Log output to a file, requires class to have outdir and label """ + logfilename = '{}/{}.log'.format(self.outdir, self.label) + fh = logging.FileHandler(logfilename) + fh.setLevel(logging.INFO) + fh.setFormatter(logging.Formatter( + '%(asctime)s %(levelname)-8s: %(message)s', + datefmt='%y-%m-%d %H:%M')) + logging.getLogger().addHandler(fh) + + def shift_matrix(self, n, dT): + """ Generate the shift matrix + + Parameters + ---------- + n: int + The dimension of the shift-matrix to generate + dT: float + The time delta of the shift matrix + + Returns + ------- + m: array (n, n) + The shift matrix + """ + + m = np.zeros((n, n)) + factorial = np.math.factorial + for i in range(n): + for j in range(n): + if i == j: + m[i, j] = 1.0 + elif i > j: + m[i, j] = 0.0 + else: + if i == 0: + m[i, j] = 2*np.pi*float(dT)**(j-i) / factorial(j-i) + else: + m[i, j] = float(dT)**(j-i) / factorial(j-i) + return m + + def shift_coefficients(self, theta, dT): + """ Shift a set of coefficients by dT + + Parameters + ---------- + theta: array-like, shape (n,) + vector of the expansion coefficients to transform starting from the + lowest degree e.g [phi, F0, F1,...]. + dT: float + difference between the two reference times as tref_new - tref_old. + + Returns + ------- + theta_new: array-like shape (n,) + vector of the coefficients as evaluate as the new reference time. + """ + + n = len(theta) + m = self.shift_matrix(n, dT) + return np.dot(m, theta) + + def calculate_thetas(self, theta, delta_thetas, tbounds, theta0_idx=0): + """ Calculates the set of coefficients for the post-glitch signal """ + thetas = [theta] + for i, dt in enumerate(delta_thetas): + if i < theta0_idx: + pre_theta_at_ith_glitch = self.shift_coefficients( + thetas[0], tbounds[i+1] - self.tref) + post_theta_at_ith_glitch = pre_theta_at_ith_glitch - dt + thetas.insert(0, self.shift_coefficients( + post_theta_at_ith_glitch, self.tref - tbounds[i+1])) + + elif i >= theta0_idx: + pre_theta_at_ith_glitch = self.shift_coefficients( + thetas[i], tbounds[i+1] - self.tref) + post_theta_at_ith_glitch = pre_theta_at_ith_glitch + dt + thetas.append(self.shift_coefficients( + post_theta_at_ith_glitch, self.tref - tbounds[i+1])) + return thetas + + def generate_loudest(self): + params = read_par(self.label, self.outdir) + for key in ['Alpha', 'Delta', 'F0', 'F1']: + if key not in params: + params[key] = self.theta_prior[key] + cmd = ('lalapps_ComputeFstatistic_v2 -a {} -d {} -f {} -s {} -D "{}"' + ' --refTime={} --outputLoudest="{}/{}.loudest" ' + '--minStartTime={} --maxStartTime={}').format( + params['Alpha'], params['Delta'], params['F0'], + params['F1'], self.sftfilepath, params['tref'], + self.outdir, self.label, self.minStartTime, + self.maxStartTime) + subprocess.call([cmd], shell=True) + + +class ComputeFstat(object): + """ Base class providing interface to `lalpulsar.ComputeFstat` """ + + earth_ephem_default = earth_ephem + sun_ephem_default = sun_ephem + + @helper_functions.initializer + def __init__(self, tref, sftfilepath=None, minStartTime=None, + maxStartTime=None, binary=False, transient=True, BSGL=False, + detector=None, minCoverFreq=None, maxCoverFreq=None, + earth_ephem=None, sun_ephem=None, injectSources=None + ): + """ + Parameters + ---------- + tref: int + GPS seconds of the reference time. + sftfilepath: str + File patern to match SFTs + minStartTime, maxStartTime: float GPStime + Only use SFTs with timestemps starting from (including, excluding) + this epoch + binary: bool + If true, search of binary parameters. + transient: bool + If true, allow for the Fstat to be computed over a transient range. + BSGL: bool + If true, compute the BSGL rather than the twoF value. + detector: str + Two character reference to the data to use, specify None for no + contraint. + minCoverFreq, maxCoverFreq: float + The min and max cover frequency passed to CreateFstatInput, if + either is None the range of frequencies in the SFT less 1Hz is + used. + earth_ephem, sun_ephem: str + Paths of the two files containing positions of Earth and Sun, + respectively at evenly spaced times, as passed to CreateFstatInput. + If None defaults defined in BaseSearchClass will be used. + + """ + + if earth_ephem is None: + self.earth_ephem = self.earth_ephem_default + if sun_ephem is None: + self.sun_ephem = self.sun_ephem_default + + self.init_computefstatistic_single_point() + + def get_SFTCatalog(self): + if hasattr(self, 'SFTCatalog'): + return + logging.info('Initialising SFTCatalog') + constraints = lalpulsar.SFTConstraints() + if self.detector: + constraints.detector = self.detector + if self.minStartTime: + constraints.minStartTime = lal.LIGOTimeGPS(self.minStartTime) + if self.maxStartTime: + constraints.maxStartTime = lal.LIGOTimeGPS(self.maxStartTime) + + logging.info('Loading data matching pattern {}'.format( + self.sftfilepath)) + SFTCatalog = lalpulsar.SFTdataFind(self.sftfilepath, constraints) + detector_names = list(set([d.header.name for d in SFTCatalog.data])) + self.detector_names = detector_names + SFT_timestamps = [d.header.epoch for d in SFTCatalog.data] + if args.quite is False and args.no_interactive is False: + try: + from bashplotlib.histogram import plot_hist + print('Data timestamps histogram:') + plot_hist(SFT_timestamps, height=5, bincount=50) + except IOError: + pass + if len(detector_names) == 0: + raise ValueError('No data loaded.') + logging.info('Loaded {} data files from detectors {}'.format( + len(SFT_timestamps), detector_names)) + logging.info('Data spans from {} ({}) to {} ({})'.format( + int(SFT_timestamps[0]), + subprocess.check_output('lalapps_tconvert {}'.format( + int(SFT_timestamps[0])), shell=True).rstrip('\n'), + int(SFT_timestamps[-1]), + subprocess.check_output('lalapps_tconvert {}'.format( + int(SFT_timestamps[-1])), shell=True).rstrip('\n'))) + self.SFTCatalog = SFTCatalog + + def get_list_of_matching_sfts(self): + matches = glob.glob(self.sftfilepath) + if len(matches) > 0: + return matches + else: + raise IOError('No sfts found matching {}'.format( + self.sftfilepath)) + + def init_computefstatistic_single_point(self): + """ Initilisation step of run_computefstatistic for a single point """ + + self.get_SFTCatalog() + + logging.info('Initialising ephems') + ephems = lalpulsar.InitBarycenter(self.earth_ephem, self.sun_ephem) + + logging.info('Initialising FstatInput') + dFreq = 0 + if self.transient: + self.whatToCompute = lalpulsar.FSTATQ_ATOMS_PER_DET + else: + self.whatToCompute = lalpulsar.FSTATQ_2F + + FstatOAs = lalpulsar.FstatOptionalArgs() + FstatOAs.randSeed = lalpulsar.FstatOptionalArgsDefaults.randSeed + FstatOAs.SSBprec = lalpulsar.FstatOptionalArgsDefaults.SSBprec + FstatOAs.Dterms = lalpulsar.FstatOptionalArgsDefaults.Dterms + FstatOAs.runningMedianWindow = lalpulsar.FstatOptionalArgsDefaults.runningMedianWindow + FstatOAs.FstatMethod = lalpulsar.FstatOptionalArgsDefaults.FstatMethod + FstatOAs.InjectSqrtSX = lalpulsar.FstatOptionalArgsDefaults.injectSqrtSX + FstatOAs.assumeSqrtSX = lalpulsar.FstatOptionalArgsDefaults.assumeSqrtSX + FstatOAs.prevInput = lalpulsar.FstatOptionalArgsDefaults.prevInput + FstatOAs.collectTiming = lalpulsar.FstatOptionalArgsDefaults.collectTiming + + if hasattr(self, 'injectSource') and type(self.injectSources) == dict: + logging.info('Injecting source with params: {}'.format( + self.injectSources)) + PPV = lalpulsar.CreatePulsarParamsVector(1) + PP = PPV.data[0] + PP.Amp.h0 = self.injectSources['h0'] + PP.Amp.cosi = self.injectSources['cosi'] + PP.Amp.phi0 = self.injectSources['phi0'] + PP.Amp.psi = self.injectSources['psi'] + PP.Doppler.Alpha = self.injectSources['Alpha'] + PP.Doppler.Delta = self.injectSources['Delta'] + PP.Doppler.fkdot = np.array(self.injectSources['fkdot']) + PP.Doppler.refTime = self.tref + if 't0' not in self.injectSources: + PP.Transient.type = lalpulsar.TRANSIENT_NONE + FstatOAs.injectSources = PPV + else: + FstatOAs.injectSources = lalpulsar.FstatOptionalArgsDefaults.injectSources + + if self.minCoverFreq is None or self.maxCoverFreq is None: + fAs = [d.header.f0 for d in self.SFTCatalog.data] + fBs = [d.header.f0 + (d.numBins-1)*d.header.deltaF + for d in self.SFTCatalog.data] + self.minCoverFreq = np.min(fAs) + 0.5 + self.maxCoverFreq = np.max(fBs) - 0.5 + logging.info('Min/max cover freqs not provided, using ' + '{} and {}, est. from SFTs'.format( + self.minCoverFreq, self.maxCoverFreq)) + + self.FstatInput = lalpulsar.CreateFstatInput(self.SFTCatalog, + self.minCoverFreq, + self.maxCoverFreq, + dFreq, + ephems, + FstatOAs + ) + + logging.info('Initialising PulsarDoplerParams') + PulsarDopplerParams = lalpulsar.PulsarDopplerParams() + PulsarDopplerParams.refTime = self.tref + PulsarDopplerParams.Alpha = 1 + PulsarDopplerParams.Delta = 1 + PulsarDopplerParams.fkdot = np.array([0, 0, 0, 0, 0, 0, 0]) + self.PulsarDopplerParams = PulsarDopplerParams + + logging.info('Initialising FstatResults') + self.FstatResults = lalpulsar.FstatResults() + + if self.BSGL: + if len(self.detector_names) < 2: + raise ValueError("Can't use BSGL with single detector data") + else: + logging.info('Initialising BSGL') + + # Tuning parameters - to be reviewed + numDetectors = 2 + if hasattr(self, 'nsegs'): + p_val_threshold = 1e-6 + Fstar0s = np.linspace(0, 1000, 10000) + p_vals = scipy.special.gammaincc(2*self.nsegs, Fstar0s) + Fstar0 = Fstar0s[np.argmin(np.abs(p_vals - p_val_threshold))] + if Fstar0 == Fstar0s[-1]: + raise ValueError('Max Fstar0 exceeded') + else: + Fstar0 = 15. + logging.info('Using Fstar0 of {:1.2f}'.format(Fstar0)) + oLGX = np.zeros(10) + oLGX[:numDetectors] = 1./numDetectors + self.BSGLSetup = lalpulsar.CreateBSGLSetup(numDetectors, + Fstar0, + oLGX, + True, + 1) + self.twoFX = np.zeros(10) + self.whatToCompute = (self.whatToCompute + + lalpulsar.FSTATQ_2F_PER_DET) + + if self.transient: + logging.info('Initialising transient parameters') + self.windowRange = lalpulsar.transientWindowRange_t() + self.windowRange.type = lalpulsar.TRANSIENT_RECTANGULAR + self.windowRange.t0Band = 0 + self.windowRange.dt0 = 1 + self.windowRange.tauBand = 0 + self.windowRange.dtau = 1 + + def compute_fullycoherent_det_stat_single_point( + self, F0, F1, F2, Alpha, Delta, asini=None, period=None, ecc=None, + tp=None, argp=None): + """ Compute the fully-coherent det. statistic at a single point """ + + return self.run_computefstatistic_single_point( + self.minStartTime, self.maxStartTime, F0, F1, F2, Alpha, Delta, + asini, period, ecc, tp, argp) + + def run_computefstatistic_single_point(self, tstart, tend, F0, F1, + F2, Alpha, Delta, asini=None, + period=None, ecc=None, tp=None, + argp=None): + """ Returns twoF or ln(BSGL) fully-coherently at a single point """ + + self.PulsarDopplerParams.fkdot = np.array([F0, F1, F2, 0, 0, 0, 0]) + self.PulsarDopplerParams.Alpha = Alpha + self.PulsarDopplerParams.Delta = Delta + if self.binary: + self.PulsarDopplerParams.asini = asini + self.PulsarDopplerParams.period = period + self.PulsarDopplerParams.ecc = ecc + self.PulsarDopplerParams.tp = tp + self.PulsarDopplerParams.argp = argp + + lalpulsar.ComputeFstat(self.FstatResults, + self.FstatInput, + self.PulsarDopplerParams, + 1, + self.whatToCompute + ) + + if self.transient is False: + if self.BSGL is False: + return self.FstatResults.twoF[0] + + twoF = np.float(self.FstatResults.twoF[0]) + self.twoFX[0] = self.FstatResults.twoFPerDet(0) + self.twoFX[1] = self.FstatResults.twoFPerDet(1) + log10_BSGL = lalpulsar.ComputeBSGL(twoF, self.twoFX, + self.BSGLSetup) + return log10_BSGL/np.log10(np.exp(1)) + + self.windowRange.t0 = int(tstart) # TYPE UINT4 + self.windowRange.tau = int(tend - tstart) # TYPE UINT4 + + FS = lalpulsar.ComputeTransientFstatMap( + self.FstatResults.multiFatoms[0], self.windowRange, False) + + if self.BSGL is False: + return 2*FS.F_mn.data[0][0] + + FstatResults_single = copy.copy(self.FstatResults) + FstatResults_single.lenth = 1 + FstatResults_single.data = self.FstatResults.multiFatoms[0].data[0] + FS0 = lalpulsar.ComputeTransientFstatMap( + FstatResults_single.multiFatoms[0], self.windowRange, False) + FstatResults_single.data = self.FstatResults.multiFatoms[0].data[1] + FS1 = lalpulsar.ComputeTransientFstatMap( + FstatResults_single.multiFatoms[0], self.windowRange, False) + + self.twoFX[0] = 2*FS0.F_mn.data[0][0] + self.twoFX[1] = 2*FS1.F_mn.data[0][0] + log10_BSGL = lalpulsar.ComputeBSGL( + 2*FS.F_mn.data[0][0], self.twoFX, self.BSGLSetup) + + return log10_BSGL/np.log10(np.exp(1)) + + def calculate_twoF_cumulative(self, F0, F1, F2, Alpha, Delta, asini=None, + period=None, ecc=None, tp=None, argp=None, + tstart=None, tend=None, npoints=1000, + minfraction=0.01, maxfraction=1): + """ Calculate the cumulative twoF along the obseration span """ + duration = tend - tstart + tstart = tstart + minfraction*duration + taus = np.linspace(minfraction*duration, maxfraction*duration, npoints) + twoFs = [] + if self.transient is False: + self.transient = True + self.init_computefstatistic_single_point() + for tau in taus: + twoFs.append(self.run_computefstatistic_single_point( + tstart=tstart, tend=tstart+tau, F0=F0, F1=F1, F2=F2, + Alpha=Alpha, Delta=Delta, asini=asini, period=period, ecc=ecc, + tp=tp, argp=argp)) + + return taus, np.array(twoFs) + + def plot_twoF_cumulative(self, label, outdir, ax=None, c='k', savefig=True, + title=None, **kwargs): + + taus, twoFs = self.calculate_twoF_cumulative(**kwargs) + if ax is None: + fig, ax = plt.subplots() + ax.plot(taus/86400., twoFs, label=label, color=c) + ax.set_xlabel(r'Days from $t_{{\rm start}}={:.0f}$'.format( + kwargs['tstart'])) + if self.BSGL: + ax.set_ylabel(r'$\log_{10}(\mathrm{BSGL})_{\rm cumulative}$') + else: + ax.set_ylabel(r'$\widetilde{2\mathcal{F}}_{\rm cumulative}$') + ax.set_xlim(0, taus[-1]/86400) + if title: + ax.set_title(title) + if savefig: + plt.tight_layout() + plt.savefig('{}/{}_twoFcumulative.png'.format(outdir, label)) + return taus, twoFs + else: + return ax + + +class SemiCoherentSearch(BaseSearchClass, ComputeFstat): + """ A semi-coherent search """ + + @helper_functions.initializer + def __init__(self, label, outdir, tref, nsegs=None, sftfilepath=None, + binary=False, BSGL=False, minStartTime=None, + maxStartTime=None, minCoverFreq=None, maxCoverFreq=None, + detector=None, earth_ephem=None, sun_ephem=None, + injectSources=None): + """ + Parameters + ---------- + label, outdir: str + A label and directory to read/write data from/to. + tref, minStartTime, maxStartTime: int + GPS seconds of the reference time, and start and end of the data. + nsegs: int + The (fixed) number of segments + sftfilepath: str + File patern to match SFTs + + For all other parameters, see pyfstat.ComputeFStat. + """ + + self.fs_file_name = "{}/{}_FS.dat".format(self.outdir, self.label) + if self.earth_ephem is None: + self.earth_ephem = self.earth_ephem_default + if self.sun_ephem is None: + self.sun_ephem = self.sun_ephem_default + self.transient = True + self.init_computefstatistic_single_point() + self.init_semicoherent_parameters() + + def init_semicoherent_parameters(self): + logging.info(('Initialising semicoherent parameters from {} to {} in' + ' {} segments').format( + self.minStartTime, self.maxStartTime, self.nsegs)) + self.transient = True + self.whatToCompute = lalpulsar.FSTATQ_2F+lalpulsar.FSTATQ_ATOMS_PER_DET + self.tboundaries = np.linspace(self.minStartTime, self.maxStartTime, + self.nsegs+1) + + def run_semi_coherent_computefstatistic_single_point( + self, F0, F1, F2, Alpha, Delta, asini=None, + period=None, ecc=None, tp=None, argp=None): + """ Returns twoF or ln(BSGL) semi-coherently at a single point """ + + self.PulsarDopplerParams.fkdot = np.array([F0, F1, F2, 0, 0, 0, 0]) + self.PulsarDopplerParams.Alpha = Alpha + self.PulsarDopplerParams.Delta = Delta + if self.binary: + self.PulsarDopplerParams.asini = asini + self.PulsarDopplerParams.period = period + self.PulsarDopplerParams.ecc = ecc + self.PulsarDopplerParams.tp = tp + self.PulsarDopplerParams.argp = argp + + lalpulsar.ComputeFstat(self.FstatResults, + self.FstatInput, + self.PulsarDopplerParams, + 1, + self.whatToCompute + ) + + if self.transient is False: + if self.BSGL is False: + return self.FstatResults.twoF[0] + + twoF = np.float(self.FstatResults.twoF[0]) + self.twoFX[0] = self.FstatResults.twoFPerDet(0) + self.twoFX[1] = self.FstatResults.twoFPerDet(1) + log10_BSGL = lalpulsar.ComputeBSGL(twoF, self.twoFX, + self.BSGLSetup) + return log10_BSGL/np.log10(np.exp(1)) + + detStat = 0 + for tstart, tend in zip(self.tboundaries[:-1], self.tboundaries[1:]): + self.windowRange.t0 = int(tstart) # TYPE UINT4 + self.windowRange.tau = int(tend - tstart) # TYPE UINT4 + + FS = lalpulsar.ComputeTransientFstatMap( + self.FstatResults.multiFatoms[0], self.windowRange, False) + + if self.BSGL is False: + detStat += 2*FS.F_mn.data[0][0] + continue + + FstatResults_single = copy.copy(self.FstatResults) + FstatResults_single.lenth = 1 + FstatResults_single.data = self.FstatResults.multiFatoms[0].data[0] + FS0 = lalpulsar.ComputeTransientFstatMap( + FstatResults_single.multiFatoms[0], self.windowRange, False) + FstatResults_single.data = self.FstatResults.multiFatoms[0].data[1] + FS1 = lalpulsar.ComputeTransientFstatMap( + FstatResults_single.multiFatoms[0], self.windowRange, False) + + self.twoFX[0] = 2*FS0.F_mn.data[0][0] + self.twoFX[1] = 2*FS1.F_mn.data[0][0] + log10_BSGL = lalpulsar.ComputeBSGL( + 2*FS.F_mn.data[0][0], self.twoFX, self.BSGLSetup) + + detStat += log10_BSGL/np.log10(np.exp(1)) + + return detStat + + +class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat): + """ A semi-coherent glitch search + + This implements a basic `semi-coherent glitch F-stat in which the data + is divided into segments either side of the proposed glitches and the + fully-coherent F-stat in each segment is summed to give the semi-coherent + F-stat + """ + + @helper_functions.initializer + def __init__(self, label, outdir, tref, minStartTime, maxStartTime, + nglitch=0, sftfilepath=None, theta0_idx=0, BSGL=False, + minCoverFreq=None, maxCoverFreq=None, + detector=None, earth_ephem=None, sun_ephem=None): + """ + Parameters + ---------- + label, outdir: str + A label and directory to read/write data from/to. + tref, minStartTime, maxStartTime: int + GPS seconds of the reference time, and start and end of the data. + nglitch: int + The (fixed) number of glitches; this can zero, but occasionally + this causes issue (in which case just use ComputeFstat). + sftfilepath: str + File patern to match SFTs + theta0_idx, int + Index (zero-based) of which segment the theta refers to - uyseful + if providing a tight prior on theta to allow the signal to jump + too theta (and not just from) + + For all other parameters, see pyfstat.ComputeFStat. + """ + + self.fs_file_name = "{}/{}_FS.dat".format(self.outdir, self.label) + if self.earth_ephem is None: + self.earth_ephem = self.earth_ephem_default + if self.sun_ephem is None: + self.sun_ephem = self.sun_ephem_default + self.transient = True + self.binary = False + self.init_computefstatistic_single_point() + + def compute_nglitch_fstat(self, F0, F1, F2, Alpha, Delta, *args): + """ Returns the semi-coherent glitch summed twoF """ + + args = list(args) + tboundaries = ([self.minStartTime] + args[-self.nglitch:] + + [self.maxStartTime]) + delta_F0s = args[-3*self.nglitch:-2*self.nglitch] + delta_F1s = args[-2*self.nglitch:-self.nglitch] + delta_F2 = np.zeros(len(delta_F0s)) + delta_phi = np.zeros(len(delta_F0s)) + theta = [0, F0, F1, F2] + delta_thetas = np.atleast_2d( + np.array([delta_phi, delta_F0s, delta_F1s, delta_F2]).T) + + thetas = self.calculate_thetas(theta, delta_thetas, tboundaries, + theta0_idx=self.theta0_idx) + + twoFSum = 0 + for i, theta_i_at_tref in enumerate(thetas): + ts, te = tboundaries[i], tboundaries[i+1] + + twoFVal = self.run_computefstatistic_single_point( + ts, te, theta_i_at_tref[1], theta_i_at_tref[2], + theta_i_at_tref[3], Alpha, Delta) + twoFSum += twoFVal + + if np.isfinite(twoFSum): + return twoFSum + else: + return -np.inf + + def compute_glitch_fstat_single(self, F0, F1, F2, Alpha, Delta, delta_F0, + delta_F1, tglitch): + """ Returns the semi-coherent glitch summed twoF for nglitch=1 + + Note: OBSOLETE, used only for testing + """ + + theta = [F0, F1, F2] + delta_theta = [delta_F0, delta_F1, 0] + tref = self.tref + + theta_at_glitch = self.shift_coefficients(theta, tglitch - tref) + theta_post_glitch_at_glitch = theta_at_glitch + delta_theta + theta_post_glitch = self.shift_coefficients( + theta_post_glitch_at_glitch, tref - tglitch) + + twoFsegA = self.run_computefstatistic_single_point( + self.minStartTime, tglitch, theta[0], theta[1], theta[2], Alpha, + Delta) + + if tglitch == self.maxStartTime: + return twoFsegA + + twoFsegB = self.run_computefstatistic_single_point( + tglitch, self.maxStartTime, theta_post_glitch[0], + theta_post_glitch[1], theta_post_glitch[2], Alpha, + Delta) + + return twoFsegA + twoFsegB + + +class Writer(BaseSearchClass): + """ Instance object for generating SFTs containing glitch signals """ + @helper_functions.initializer + def __init__(self, label='Test', tstart=700000000, duration=100*86400, + dtglitch=None, delta_phi=0, delta_F0=0, delta_F1=0, + delta_F2=0, tref=None, F0=30, F1=1e-10, F2=0, Alpha=5e-3, + Delta=6e-2, h0=0.1, cosi=0.0, psi=0.0, phi=0, Tsft=1800, + outdir=".", sqrtSX=1, Band=4, detector='H1', + minStartTime=None, maxStartTime=None): + """ + Parameters + ---------- + label: string + a human-readable label to be used in naming the output files + tstart, tend : float + start and end times (in gps seconds) of the total observation span + dtglitch: float + time (in gps seconds) of the glitch after tstart. To create data + without a glitch, set dtglitch=tend-tstart or leave as None + delta_phi, delta_F0, delta_F1: float + instanteneous glitch magnitudes in rad, Hz, and Hz/s respectively + tref: float or None + reference time (default is None, which sets the reference time to + tstart) + F0, F1, F2, Alpha, Delta, h0, cosi, psi, phi: float + frequency, sky-position, and amplitude parameters + Tsft: float + the sft duration + minStartTime, maxStartTime: float + if not None, the total span of data, this can be used to generate + transient signals + + see `lalapps_Makefakedata_v5 --help` for help with the other paramaters + """ + + for d in self.delta_phi, self.delta_F0, self.delta_F1, self.delta_F2: + if np.size(d) == 1: + d = [d] + self.tend = self.tstart + self.duration + if self.minStartTime is None: + self.minStartTime = self.tstart + if self.maxStartTime is None: + self.maxStartTime = self.tend + if self.dtglitch is None or self.dtglitch == self.duration: + self.tbounds = [self.tstart, self.tend] + elif np.size(self.dtglitch) == 1: + self.tbounds = [self.tstart, self.tstart+self.dtglitch, self.tend] + else: + self.tglitch = self.tstart + np.array(self.dtglitch) + self.tbounds = [self.tstart] + list(self.tglitch) + [self.tend] + + if os.path.isdir(self.outdir) is False: + os.makedirs(self.outdir) + if self.tref is None: + self.tref = self.tstart + self.tend = self.tstart + self.duration + tbs = np.array(self.tbounds) + self.durations_days = (tbs[1:] - tbs[:-1]) / 86400 + self.config_file_name = "{}/{}.cff".format(outdir, label) + + self.theta = np.array([phi, F0, F1, F2]) + self.delta_thetas = np.atleast_2d( + np.array([delta_phi, delta_F0, delta_F1, delta_F2]).T) + + self.data_duration = self.maxStartTime - self.minStartTime + numSFTs = int(float(self.data_duration) / self.Tsft) + self.sftfilename = lalpulsar.OfficialSFTFilename( + 'H', '1', numSFTs, self.Tsft, self.minStartTime, + self.data_duration, self.label) + self.sftfilepath = '{}/{}'.format(self.outdir, self.sftfilename) + self.calculate_fmin_Band() + + def make_data(self): + ''' A convienience wrapper to generate a cff file then sfts ''' + self.make_cff() + self.run_makefakedata() + + def get_single_config_line(self, i, Alpha, Delta, h0, cosi, psi, phi, F0, + F1, F2, tref, tstart, duration_days): + template = ( +"""[TS{}] +Alpha = {:1.18e} +Delta = {:1.18e} +h0 = {:1.18e} +cosi = {:1.18e} +psi = {:1.18e} +phi0 = {:1.18e} +Freq = {:1.18e} +f1dot = {:1.18e} +f2dot = {:1.18e} +refTime = {:10.6f} +transientWindowType=rect +transientStartTime={:10.3f} +transientTauDays={:1.3f}\n""") + return template.format(i, Alpha, Delta, h0, cosi, psi, phi, F0, F1, + F2, tref, tstart, duration_days) + + def make_cff(self): + """ + Generates an .cff file for a 'glitching' signal + + """ + + thetas = self.calculate_thetas(self.theta, self.delta_thetas, + self.tbounds) + + content = '' + for i, (t, d, ts) in enumerate(zip(thetas, self.durations_days, + self.tbounds[:-1])): + line = self.get_single_config_line( + i, self.Alpha, self.Delta, self.h0, self.cosi, self.psi, + t[0], t[1], t[2], t[3], self.tref, ts, d) + + content += line + + if self.check_if_cff_file_needs_rewritting(content): + config_file = open(self.config_file_name, "w+") + config_file.write(content) + config_file.close() + + def calculate_fmin_Band(self): + self.fmin = self.F0 - .5 * self.Band + + def check_cached_data_okay_to_use(self, cl): + """ Check if cached data exists and, if it does, if it can be used """ + + getmtime = os.path.getmtime + + if os.path.isfile(self.sftfilepath) is False: + logging.info('No SFT file matching {} found'.format( + self.sftfilepath)) + return False + else: + logging.info('Matching SFT file found') + + if getmtime(self.sftfilepath) < getmtime(self.config_file_name): + logging.info( + ('The config file {} has been modified since the sft file {} ' + + 'was created').format( + self.config_file_name, self.sftfilepath)) + return False + + logging.info( + 'The config file {} is older than the sft file {}'.format( + self.config_file_name, self.sftfilepath)) + logging.info('Checking contents of cff file') + logging.info('Execute: {}'.format( + 'lalapps_SFTdumpheader {} | head -n 20'.format(self.sftfilepath))) + output = subprocess.check_output( + 'lalapps_SFTdumpheader {} | head -n 20'.format(self.sftfilepath), + shell=True) + calls = [line for line in output.split('\n') if line[:3] == 'lal'] + if calls[0] == cl: + logging.info('Contents matched, use old sft file') + return True + else: + logging.info('Contents unmatched, create new sft file') + return False + + def check_if_cff_file_needs_rewritting(self, content): + """ Check if the .cff file has changed + + Returns True if the file should be overwritten - where possible avoid + overwriting to allow cached data to be used + """ + if os.path.isfile(self.config_file_name) is False: + logging.info('No config file {} found'.format( + self.config_file_name)) + return True + else: + logging.info('Config file {} already exists'.format( + self.config_file_name)) + + with open(self.config_file_name, 'r') as f: + file_content = f.read() + if file_content == content: + logging.info( + 'File contents match, no update of {} required'.format( + self.config_file_name)) + return False + else: + logging.info( + 'File contents unmatched, updating {}'.format( + self.config_file_name)) + return True + + def run_makefakedata(self): + """ Generate the sft data from the configuration file """ + + # Remove old data: + try: + os.unlink("{}/*{}*.sft".format(self.outdir, self.label)) + except OSError: + pass + + cl = [] + cl.append('lalapps_Makefakedata_v5') + cl.append('--outSingleSFT=TRUE') + cl.append('--outSFTdir="{}"'.format(self.outdir)) + cl.append('--outLabel="{}"'.format(self.label)) + cl.append('--IFOs="{}"'.format(self.detector)) + cl.append('--sqrtSX="{}"'.format(self.sqrtSX)) + if self.minStartTime is None: + cl.append('--startTime={:10.9f}'.format(float(self.tstart))) + else: + cl.append('--startTime={:10.9f}'.format(float(self.minStartTime))) + if self.maxStartTime is None: + cl.append('--duration={}'.format(int(self.duration))) + else: + data_duration = self.maxStartTime - self.minStartTime + cl.append('--duration={}'.format(int(data_duration))) + cl.append('--fmin={}'.format(int(self.fmin))) + cl.append('--Band={}'.format(self.Band)) + cl.append('--Tsft={}'.format(self.Tsft)) + if self.h0 != 0: + cl.append('--injectionSources="{}"'.format(self.config_file_name)) + + cl = " ".join(cl) + + if self.check_cached_data_okay_to_use(cl) is False: + logging.info("Executing: " + cl) + os.system(cl) + os.system('\n') + + def predict_fstat(self): + """ Wrapper to lalapps_PredictFstat """ + c_l = [] + c_l.append("lalapps_PredictFstat") + c_l.append("--h0={}".format(self.h0)) + c_l.append("--cosi={}".format(self.cosi)) + c_l.append("--psi={}".format(self.psi)) + c_l.append("--Alpha={}".format(self.Alpha)) + c_l.append("--Delta={}".format(self.Delta)) + c_l.append("--Freq={}".format(self.F0)) + + c_l.append("--DataFiles='{}'".format( + self.outdir+"/*SFT_"+self.label+"*sft")) + c_l.append("--assumeSqrtSX={}".format(self.sqrtSX)) + + c_l.append("--minStartTime={}".format(int(self.minStartTime))) + c_l.append("--maxStartTime={}".format(int(self.maxStartTime))) + + logging.info("Executing: " + " ".join(c_l) + "\n") + output = subprocess.check_output(" ".join(c_l), shell=True) + twoF = float(output.split('\n')[-2]) + return float(twoF) diff --git a/pyfstat/grid_based_searches.py b/pyfstat/grid_based_searches.py new file mode 100644 index 0000000000000000000000000000000000000000..d86364c4ec138b16b4e4295db5cc891a1181a2f2 --- /dev/null +++ b/pyfstat/grid_based_searches.py @@ -0,0 +1,292 @@ +""" Searches using grid-based methods """ + +import os +import logging +import itertools +from collections import OrderedDict + +import numpy as np +import matplotlib +import matplotlib.pyplot as plt + +import helper_functions +from core import BaseSearchClass, ComputeFstat, SemiCoherentGlitchSearch +from core import tqdm, args, earth_ephem, sun_ephem + + +class GridSearch(BaseSearchClass): + """ Gridded search using ComputeFstat """ + @helper_functions.initializer + def __init__(self, label, outdir, sftfilepath, F0s=[0], F1s=[0], F2s=[0], + Alphas=[0], Deltas=[0], tref=None, minStartTime=None, + maxStartTime=None, BSGL=False, minCoverFreq=None, + maxCoverFreq=None, earth_ephem=None, sun_ephem=None, + detector=None): + """ + Parameters + ---------- + label, outdir: str + A label and directory to read/write data from/to + sftfilepath: str + File patern to match SFTs + F0s, F1s, F2s, delta_F0s, delta_F1s, tglitchs, Alphas, Deltas: tuple + Length 3 tuple describing the grid for each parameter, e.g + [F0min, F0max, dF0], for a fixed value simply give [F0]. + tref, minStartTime, maxStartTime: int + GPS seconds of the reference time, start time and end time + + For all other parameters, see `pyfstat.ComputeFStat` for details + """ + + if earth_ephem is None: + self.earth_ephem = self.earth_ephem_default + if sun_ephem is None: + self.sun_ephem = self.sun_ephem_default + + if os.path.isdir(outdir) is False: + os.mkdir(outdir) + self.out_file = '{}/{}_gridFS.txt'.format(self.outdir, self.label) + self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta'] + + def inititate_search_object(self): + logging.info('Setting up search object') + self.search = ComputeFstat( + tref=self.tref, sftfilepath=self.sftfilepath, + minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq, + earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem, + detector=self.detector, transient=False, + minStartTime=self.minStartTime, maxStartTime=self.maxStartTime, + BSGL=self.BSGL) + + def get_array_from_tuple(self, x): + if len(x) == 1: + return np.array(x) + else: + return np.arange(x[0], x[1]*(1+1e-15), x[2]) + + def get_input_data_array(self): + arrays = [] + for tup in ([self.minStartTime], [self.maxStartTime], self.F0s, self.F1s, self.F2s, + self.Alphas, self.Deltas): + arrays.append(self.get_array_from_tuple(tup)) + + input_data = [] + for vals in itertools.product(*arrays): + input_data.append(vals) + + self.arrays = arrays + self.input_data = np.array(input_data) + + def check_old_data_is_okay_to_use(self): + if args.clean: + return False + if os.path.isfile(self.out_file) is False: + logging.info('No old data found, continuing with grid search') + return False + data = np.atleast_2d(np.genfromtxt(self.out_file, delimiter=' ')) + if np.all(data[:, 0:-1] == self.input_data): + logging.info( + 'Old data found with matching input, no search performed') + return data + else: + logging.info( + 'Old data found, input differs, continuing with grid search') + return False + + def run(self, return_data=False): + self.get_input_data_array() + old_data = self.check_old_data_is_okay_to_use() + if old_data is not False: + self.data = old_data + return + + self.inititate_search_object() + + logging.info('Total number of grid points is {}'.format( + len(self.input_data))) + + data = [] + for vals in tqdm(self.input_data): + FS = self.search.run_computefstatistic_single_point(*vals) + data.append(list(vals) + [FS]) + + data = np.array(data) + if return_data: + return data + else: + logging.info('Saving data to {}'.format(self.out_file)) + np.savetxt(self.out_file, data, delimiter=' ') + self.data = data + + def convert_F0_to_mismatch(self, F0, F0hat, Tseg): + DeltaF0 = F0[1] - F0[0] + m_spacing = (np.pi*Tseg*DeltaF0)**2 / 12. + N = len(F0) + return np.arange(-N*m_spacing/2., N*m_spacing/2., m_spacing) + + def convert_F1_to_mismatch(self, F1, F1hat, Tseg): + DeltaF1 = F1[1] - F1[0] + m_spacing = (np.pi*Tseg**2*DeltaF1)**2 / 720. + N = len(F1) + return np.arange(-N*m_spacing/2., N*m_spacing/2., m_spacing) + + def add_mismatch_to_ax(self, ax, x, y, xkey, ykey, xhat, yhat, Tseg): + axX = ax.twiny() + axX.zorder = -10 + axY = ax.twinx() + axY.zorder = -10 + + if xkey == 'F0': + m = self.convert_F0_to_mismatch(x, xhat, Tseg) + axX.set_xlim(m[0], m[-1]) + + if ykey == 'F1': + m = self.convert_F1_to_mismatch(y, yhat, Tseg) + axY.set_ylim(m[0], m[-1]) + + def plot_1D(self, xkey): + fig, ax = plt.subplots() + xidx = self.keys.index(xkey) + x = np.unique(self.data[:, xidx]) + z = self.data[:, -1] + plt.plot(x, z) + fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label)) + + def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None, + add_mismatch=None, xN=None, yN=None, flat_keys=[], + rel_flat_idxs=[], flatten_method=np.max, + predicted_twoF=None, cm=None, cbarkwargs={}): + """ Plots a 2D grid of 2F values + + Parameters + ---------- + add_mismatch: tuple (xhat, yhat, Tseg) + If not None, add a secondary axis with the metric mismatch from the + point xhat, yhat with duration Tseg + flatten_method: np.max + Function to use in flattening the flat_keys + """ + if ax is None: + fig, ax = plt.subplots() + xidx = self.keys.index(xkey) + yidx = self.keys.index(ykey) + flat_idxs = [self.keys.index(k) for k in flat_keys] + + x = np.unique(self.data[:, xidx]) + y = np.unique(self.data[:, yidx]) + flat_vals = [np.unique(self.data[:, j]) for j in flat_idxs] + z = self.data[:, -1] + + Y, X = np.meshgrid(y, x) + shape = [len(x), len(y)] + [len(v) for v in flat_vals] + Z = z.reshape(shape) + + if len(rel_flat_idxs) > 0: + Z = flatten_method(Z, axis=tuple(rel_flat_idxs)) + + if predicted_twoF: + Z = (predicted_twoF - Z) / (predicted_twoF + 4) + if cm is None: + cm = plt.cm.viridis_r + else: + if cm is None: + cm = plt.cm.viridis + + pax = ax.pcolormesh(X, Y, Z, cmap=cm, vmin=vmin, vmax=vmax) + cb = plt.colorbar(pax, ax=ax, **cbarkwargs) + cb.set_label('$2\mathcal{F}$') + + if add_mismatch: + self.add_mismatch_to_ax(ax, x, y, xkey, ykey, *add_mismatch) + + ax.set_xlim(x[0], x[-1]) + ax.set_ylim(y[0], y[-1]) + labels = {'F0': '$f$', 'F1': '$\dot{f}$'} + ax.set_xlabel(labels[xkey]) + ax.set_ylabel(labels[ykey]) + + if xN: + ax.xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(xN)) + if yN: + ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(yN)) + + if save: + fig.tight_layout() + fig.savefig('{}/{}_2D.png'.format(self.outdir, self.label)) + else: + return ax + + def get_max_twoF(self): + twoF = self.data[:, -1] + idx = np.argmax(twoF) + v = self.data[idx, :] + d = OrderedDict(minStartTime=v[0], maxStartTime=v[1], F0=v[2], F1=v[3], + F2=v[4], Alpha=v[5], Delta=v[6], twoF=v[7]) + return d + + def print_max_twoF(self): + d = self.get_max_twoF() + print('Max twoF values for {}:'.format(self.label)) + for k, v in d.iteritems(): + print(' {}={}'.format(k, v)) + + +class GridGlitchSearch(GridSearch): + """ Grid search using the SemiCoherentGlitchSearch """ + @helper_functions.initializer + def __init__(self, label, outdir, sftfilepath=None, F0s=[0], + F1s=[0], F2s=[0], delta_F0s=[0], delta_F1s=[0], tglitchs=None, + Alphas=[0], Deltas=[0], tref=None, minStartTime=None, + maxStartTime=None, minCoverFreq=None, maxCoverFreq=None, + write_after=1000, earth_ephem=None, sun_ephem=None): + + """ + Parameters + ---------- + label, outdir: str + A label and directory to read/write data from/to + sftfilepath: str + File patern to match SFTs + F0s, F1s, F2s, delta_F0s, delta_F1s, tglitchs, Alphas, Deltas: tuple + Length 3 tuple describing the grid for each parameter, e.g + [F0min, F0max, dF0], for a fixed value simply give [F0]. + tref, minStartTime, maxStartTime: int + GPS seconds of the reference time, start time and end time + + For all other parameters, see pyfstat.ComputeFStat. + """ + if tglitchs is None: + self.tglitchs = [self.maxStartTime] + if earth_ephem is None: + self.earth_ephem = self.earth_ephem_default + if sun_ephem is None: + self.sun_ephem = self.sun_ephem_default + + self.search = SemiCoherentGlitchSearch( + label=label, outdir=outdir, sftfilepath=self.sftfilepath, + tref=tref, minStartTime=minStartTime, maxStartTime=maxStartTime, + minCoverFreq=minCoverFreq, maxCoverFreq=maxCoverFreq, + earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem, + BSGL=self.BSGL) + + if os.path.isdir(outdir) is False: + os.mkdir(outdir) + self.out_file = '{}/{}_gridFS.txt'.format(self.outdir, self.label) + self.keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta', 'delta_F0', + 'delta_F1', 'tglitch'] + + def get_input_data_array(self): + arrays = [] + for tup in (self.F0s, self.F1s, self.F2s, self.Alphas, self.Deltas, + self.delta_F0s, self.delta_F1s, self.tglitchs): + arrays.append(self.get_array_from_tuple(tup)) + + input_data = [] + for vals in itertools.product(*arrays): + input_data.append(vals) + + self.arrays = arrays + self.input_data = np.array(input_data) + + + diff --git a/pyfstat/helper_functions.py b/pyfstat/helper_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..3034df7fad0d78956797dd14461658176bee23e4 --- /dev/null +++ b/pyfstat/helper_functions.py @@ -0,0 +1,120 @@ +""" +Provides helpful functions to facilitate ease-of-use of pyfstat +""" + +import os +import sys +import argparse +import logging +import inspect +from functools import wraps + +import matplotlib.pyplot as plt +import numpy as np + + +def set_up_optional_tqdm(): + try: + from tqdm import tqdm + except ImportError: + def tqdm(x, *args, **kwargs): + return x + return tqdm + + +def set_up_matplotlib_defaults(): + plt.switch_backend('Agg') + plt.rcParams['text.usetex'] = True + plt.rcParams['axes.formatter.useoffset'] = False + + +def set_up_command_line_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument("-q", "--quite", help="Decrease output verbosity", + action="store_true") + parser.add_argument("--no-interactive", help="Don't use interactive", + action="store_true") + parser.add_argument("-c", "--clean", help="Don't use cached data", + action="store_true") + parser.add_argument("-u", "--use-old-data", action="store_true") + parser.add_argument('-s', "--setup-only", action="store_true") + parser.add_argument('-n', "--no-template-counting", action="store_true") + parser.add_argument('unittest_args', nargs='*') + args, unknown = parser.parse_known_args() + sys.argv[1:] = args.unittest_args + if args.quite or args.no_interactive: + def tqdm(x, *args, **kwargs): + return x + logger = logging.getLogger() + logger.setLevel(logging.DEBUG) + stream_handler = logging.StreamHandler() + if args.quite: + stream_handler.setLevel(logging.WARNING) + else: + stream_handler.setLevel(logging.DEBUG) + stream_handler.setFormatter(logging.Formatter( + '%(asctime)s %(levelname)-8s: %(message)s', datefmt='%H:%M')) + logger.addHandler(stream_handler) + return args + + +def set_up_ephemeris_configuration(): + config_file = os.path.expanduser('~')+'/.pyfstat.conf' + if os.path.isfile(config_file): + d = {} + with open(config_file, 'r') as f: + for line in f: + k, v = line.split('=') + k = k.replace(' ', '') + for item in [' ', "'", '"', '\n']: + v = v.replace(item, '') + d[k] = v + earth_ephem = d['earth_ephem'] + sun_ephem = d['sun_ephem'] + else: + logging.warning('No ~/.pyfstat.conf file found please provide the ' + 'paths when initialising searches') + earth_ephem = None + sun_ephem = None + return earth_ephem, sun_ephem + + +def round_to_n(x, n): + if not x: + return 0 + power = -int(np.floor(np.log10(abs(x)))) + (n - 1) + factor = (10 ** power) + return round(x * factor) / factor + + +def texify_float(x, d=2): + if type(x) == str: + return x + x = round_to_n(x, d) + if 0.01 < abs(x) < 100: + return str(x) + else: + power = int(np.floor(np.log10(abs(x)))) + stem = np.round(x / 10**power, d) + if d == 1: + stem = int(stem) + return r'${}{{\times}}10^{{{}}}$'.format(stem, power) + + +def initializer(func): + """ Decorator function to automatically assign the parameters to self """ + names, varargs, keywords, defaults = inspect.getargspec(func) + + @wraps(func) + def wrapper(self, *args, **kargs): + for name, arg in list(zip(names[1:], args)) + list(kargs.items()): + setattr(self, name, arg) + + for name, default in zip(reversed(names), reversed(defaults)): + if not hasattr(self, name): + setattr(self, name, default) + + func(self, *args, **kargs) + + return wrapper + diff --git a/pyfstat.py b/pyfstat/mcmc_based_searches.py old mode 100755 new mode 100644 similarity index 56% rename from pyfstat.py rename to pyfstat/mcmc_based_searches.py index c6e798729ccd23f90a261440a6c77b1ffc441fc6..f7ce22c84f313a20ab40dff31221505356180826 --- a/pyfstat.py +++ b/pyfstat/mcmc_based_searches.py @@ -1,908 +1,27 @@ -""" Classes for various types of searches using ComputeFstatistic """ -import os +""" Searches using MCMC-based methods """ + import sys -import itertools -import logging -import argparse +import os import copy -import glob -import inspect -from functools import wraps -import subprocess +import logging from collections import OrderedDict import numpy as np import matplotlib import matplotlib.pyplot as plt -import scipy.special -import scipy.optimize import emcee import corner import dill as pickle -import lal -import lalpulsar - - -def set_up_optional_tqdm(): - try: - from tqdm import tqdm - except ImportError: - def tqdm(x, *args, **kwargs): - return x - - -def set_up_matplotlib_defaults(): - plt.switch_backend('Agg') - plt.rcParams['text.usetex'] = True - plt.rcParams['axes.formatter.useoffset'] = False - - -def set_up_ephemeris_configuration(): - config_file = os.path.expanduser('~')+'/.pyfstat.conf' - if os.path.isfile(config_file): - d = {} - with open(config_file, 'r') as f: - for line in f: - k, v = line.split('=') - k = k.replace(' ', '') - for item in [' ', "'", '"', '\n']: - v = v.replace(item, '') - d[k] = v - earth_ephem = d['earth_ephem'] - sun_ephem = d['sun_ephem'] - else: - logging.warning('No ~/.pyfstat.conf file found please provide the ' - 'paths when initialising searches') - earth_ephem = None - sun_ephem = None - - -def set_up_command_line_arguments(): - parser = argparse.ArgumentParser() - parser.add_argument("-q", "--quite", help="Decrease output verbosity", - action="store_true") - parser.add_argument("--no-interactive", help="Don't use interactive", - action="store_true") - parser.add_argument("-c", "--clean", help="Don't use cached data", - action="store_true") - parser.add_argument("-u", "--use-old-data", action="store_true") - parser.add_argument('-s', "--setup-only", action="store_true") - parser.add_argument('-n', "--no-template-counting", action="store_true") - parser.add_argument('unittest_args', nargs='*') - args, unknown = parser.parse_known_args() - sys.argv[1:] = args.unittest_args - if args.quite or args.no_interactive: - def tqdm(x, *args, **kwargs): - return x - logger = logging.getLogger() - logger.setLevel(logging.DEBUG) - stream_handler = logging.StreamHandler() - if args.quite: - stream_handler.setLevel(logging.WARNING) - else: - stream_handler.setLevel(logging.DEBUG) - stream_handler.setFormatter(logging.Formatter( - '%(asctime)s %(levelname)-8s: %(message)s', datefmt='%H:%M')) - logger.addHandler(stream_handler) - -set_up_optional_tqdm() -set_up_matplotlib_defaults() -set_up_ephemeris_configuration() -set_up_command_line_arguments() - - -def round_to_n(x, n): - if not x: - return 0 - power = -int(np.floor(np.log10(abs(x)))) + (n - 1) - factor = (10 ** power) - return round(x * factor) / factor - - -def texify_float(x, d=2): - if type(x) == str: - return x - x = round_to_n(x, d) - if 0.01 < abs(x) < 100: - return str(x) - else: - power = int(np.floor(np.log10(abs(x)))) - stem = np.round(x / 10**power, d) - if d == 1: - stem = int(stem) - return r'${}{{\times}}10^{{{}}}$'.format(stem, power) - - -def initializer(func): - """ Decorator function to automatically assign the parameters to self """ - names, varargs, keywords, defaults = inspect.getargspec(func) - - @wraps(func) - def wrapper(self, *args, **kargs): - for name, arg in list(zip(names[1:], args)) + list(kargs.items()): - setattr(self, name, arg) - - for name, default in zip(reversed(names), reversed(defaults)): - if not hasattr(self, name): - setattr(self, name, default) - - func(self, *args, **kargs) - - return wrapper - - -def read_par(label, outdir): - """ Read in a .par file, returns a dictionary of the values """ - filename = '{}/{}.par'.format(outdir, label) - d = {} - with open(filename, 'r') as f: - for line in f: - if len(line.split('=')) > 1: - key, val = line.rstrip('\n').split(' = ') - key = key.strip() - d[key] = np.float64(eval(val.rstrip('; '))) - return d - - -def get_optimal_setup( - R, Nsegs0, tref, minStartTime, maxStartTime, DeltaOmega, - DeltaFs, fiducial_freq, detector_names, earth_ephem, sun_ephem): - logging.info('Calculating optimal setup for R={}, Nsegs0={}'.format( - R, Nsegs0)) - - V_0 = get_V_estimate( - Nsegs0, tref, minStartTime, maxStartTime, DeltaOmega, DeltaFs, - fiducial_freq, detector_names, earth_ephem, sun_ephem) - logging.info('Stage {}, nsegs={}, V={}'.format(0, Nsegs0, V_0)) - - nsegs_vals = [Nsegs0] - V_vals = [V_0] - - i = 0 - nsegs_i = Nsegs0 - while nsegs_i > 1: - nsegs_i, V_i = get_nsegs_ip1( - nsegs_i, R, tref, minStartTime, maxStartTime, DeltaOmega, - DeltaFs, fiducial_freq, detector_names, earth_ephem, sun_ephem) - nsegs_vals.append(nsegs_i) - V_vals.append(V_i) - i += 1 - logging.info( - 'Stage {}, nsegs={}, V={}'.format(i, nsegs_i, V_i)) - - return nsegs_vals, V_vals - - -def get_nsegs_ip1( - nsegs_i, R, tref, minStartTime, maxStartTime, DeltaOmega, - DeltaFs, fiducial_freq, detector_names, earth_ephem, sun_ephem): - - log10R = np.log10(R) - log10Vi = np.log10(get_V_estimate( - nsegs_i, tref, minStartTime, maxStartTime, DeltaOmega, DeltaFs, - fiducial_freq, detector_names, earth_ephem, sun_ephem)) - - def f(nsegs_ip1): - if nsegs_ip1[0] > nsegs_i: - return 1e6 - if nsegs_ip1[0] < 0: - return 1e6 - nsegs_ip1 = int(nsegs_ip1[0]) - if nsegs_ip1 == 0: - nsegs_ip1 = 1 - Vip1 = get_V_estimate( - nsegs_ip1, tref, minStartTime, maxStartTime, DeltaOmega, - DeltaFs, fiducial_freq, detector_names, earth_ephem, sun_ephem) - if Vip1[0] is None: - return 1e6 - else: - log10Vip1 = np.log10(Vip1) - return np.abs(log10Vi[0] + log10R - log10Vip1[0]) - res = scipy.optimize.minimize(f, .5*nsegs_i, method='Powell', tol=0.1, - options={'maxiter': 10}) - nsegs_ip1 = int(res.x) - if nsegs_ip1 == 0: - nsegs_ip1 = 1 - if res.success: - return nsegs_ip1, get_V_estimate( - nsegs_ip1, tref, minStartTime, maxStartTime, DeltaOmega, DeltaFs, - fiducial_freq, detector_names, earth_ephem, sun_ephem) - else: - raise ValueError('Optimisation unsuccesful') - - -def get_V_estimate( - nsegs, tref, minStartTime, maxStartTime, DeltaOmega, DeltaFs, - fiducial_freq, detector_names, earth_ephem, sun_ephem): - """ Returns V, Vsky, Vpe estimated from the super-sky metric - - Parameters - ---------- - nsegs: int - Number of semi-coherent segments - tref: int - Reference time in GPS seconds - minStartTime, maxStartTime: int - Minimum and maximum SFT timestamps - DeltaOmega: float - Solid angle of the sky-patch - DeltaFs: array - Array of [DeltaF0, DeltaF1, ...], length determines the number of - spin-down terms. - fiducial_freq: float - Fidicual frequency - detector_names: array - Array of detectors to average over - earth_ephem, sun_ephem: st - Paths to the ephemeris files - - """ - spindowns = len(DeltaFs) - 1 - tboundaries = np.linspace(minStartTime, maxStartTime, nsegs+1) - - ref_time = lal.LIGOTimeGPS(tref) - segments = lal.SegListCreate() - for j in range(len(tboundaries)-1): - seg = lal.SegCreate(lal.LIGOTimeGPS(tboundaries[j]), - lal.LIGOTimeGPS(tboundaries[j+1]), - j) - lal.SegListAppend(segments, seg) - detNames = lal.CreateStringVector(*detector_names) - detectors = lalpulsar.MultiLALDetector() - lalpulsar.ParseMultiLALDetector(detectors, detNames) - detector_weights = None - detector_motion = (lalpulsar.DETMOTION_SPIN - + lalpulsar.DETMOTION_ORBIT) - ephemeris = lalpulsar.InitBarycenter(earth_ephem, sun_ephem) - try: - SSkyMetric = lalpulsar.ComputeSuperskyMetrics( - spindowns, ref_time, segments, fiducial_freq, detectors, - detector_weights, detector_motion, ephemeris) - except RuntimeError as e: - logging.debug('Encountered run-time error {}'.format(e)) - return None, None, None - - sqrtdetG_SKY = np.sqrt(np.linalg.det( - SSkyMetric.semi_rssky_metric.data[:2, :2])) - sqrtdetG_PE = np.sqrt(np.linalg.det( - SSkyMetric.semi_rssky_metric.data[2:, 2:])) - - Vsky = .5*sqrtdetG_SKY*DeltaOmega - Vpe = sqrtdetG_PE * np.prod(DeltaFs) - if Vsky == 0: - Vsky = 1 - if Vpe == 0: - Vpe = 1 - return (Vsky * Vpe, Vsky, Vpe) - - -class BaseSearchClass(object): - """ The base search class, provides general functions """ - - earth_ephem_default = earth_ephem - sun_ephem_default = sun_ephem - - def add_log_file(self): - """ Log output to a file, requires class to have outdir and label """ - logfilename = '{}/{}.log'.format(self.outdir, self.label) - fh = logging.FileHandler(logfilename) - fh.setLevel(logging.INFO) - fh.setFormatter(logging.Formatter( - '%(asctime)s %(levelname)-8s: %(message)s', - datefmt='%y-%m-%d %H:%M')) - logging.getLogger().addHandler(fh) - - def shift_matrix(self, n, dT): - """ Generate the shift matrix - - Parameters - ---------- - n: int - The dimension of the shift-matrix to generate - dT: float - The time delta of the shift matrix - - Returns - ------- - m: array (n, n) - The shift matrix - """ - - m = np.zeros((n, n)) - factorial = np.math.factorial - for i in range(n): - for j in range(n): - if i == j: - m[i, j] = 1.0 - elif i > j: - m[i, j] = 0.0 - else: - if i == 0: - m[i, j] = 2*np.pi*float(dT)**(j-i) / factorial(j-i) - else: - m[i, j] = float(dT)**(j-i) / factorial(j-i) - return m - - def shift_coefficients(self, theta, dT): - """ Shift a set of coefficients by dT - - Parameters - ---------- - theta: array-like, shape (n,) - vector of the expansion coefficients to transform starting from the - lowest degree e.g [phi, F0, F1,...]. - dT: float - difference between the two reference times as tref_new - tref_old. - - Returns - ------- - theta_new: array-like shape (n,) - vector of the coefficients as evaluate as the new reference time. - """ - - n = len(theta) - m = self.shift_matrix(n, dT) - return np.dot(m, theta) - - def calculate_thetas(self, theta, delta_thetas, tbounds, theta0_idx=0): - """ Calculates the set of coefficients for the post-glitch signal """ - thetas = [theta] - for i, dt in enumerate(delta_thetas): - if i < theta0_idx: - pre_theta_at_ith_glitch = self.shift_coefficients( - thetas[0], tbounds[i+1] - self.tref) - post_theta_at_ith_glitch = pre_theta_at_ith_glitch - dt - thetas.insert(0, self.shift_coefficients( - post_theta_at_ith_glitch, self.tref - tbounds[i+1])) - - elif i >= theta0_idx: - pre_theta_at_ith_glitch = self.shift_coefficients( - thetas[i], tbounds[i+1] - self.tref) - post_theta_at_ith_glitch = pre_theta_at_ith_glitch + dt - thetas.append(self.shift_coefficients( - post_theta_at_ith_glitch, self.tref - tbounds[i+1])) - return thetas - - def generate_loudest(self): - params = read_par(self.label, self.outdir) - for key in ['Alpha', 'Delta', 'F0', 'F1']: - if key not in params: - params[key] = self.theta_prior[key] - cmd = ('lalapps_ComputeFstatistic_v2 -a {} -d {} -f {} -s {} -D "{}"' - ' --refTime={} --outputLoudest="{}/{}.loudest" ' - '--minStartTime={} --maxStartTime={}').format( - params['Alpha'], params['Delta'], params['F0'], - params['F1'], self.sftfilepath, params['tref'], - self.outdir, self.label, self.minStartTime, - self.maxStartTime) - subprocess.call([cmd], shell=True) - - -class ComputeFstat(object): - """ Base class providing interface to `lalpulsar.ComputeFstat` """ - - earth_ephem_default = earth_ephem - sun_ephem_default = sun_ephem - - @initializer - def __init__(self, tref, sftfilepath=None, minStartTime=None, - maxStartTime=None, binary=False, transient=True, BSGL=False, - detector=None, minCoverFreq=None, maxCoverFreq=None, - earth_ephem=None, sun_ephem=None, injectSources=None - ): - """ - Parameters - ---------- - tref: int - GPS seconds of the reference time. - sftfilepath: str - File patern to match SFTs - minStartTime, maxStartTime: float GPStime - Only use SFTs with timestemps starting from (including, excluding) - this epoch - binary: bool - If true, search of binary parameters. - transient: bool - If true, allow for the Fstat to be computed over a transient range. - BSGL: bool - If true, compute the BSGL rather than the twoF value. - detector: str - Two character reference to the data to use, specify None for no - contraint. - minCoverFreq, maxCoverFreq: float - The min and max cover frequency passed to CreateFstatInput, if - either is None the range of frequencies in the SFT less 1Hz is - used. - earth_ephem, sun_ephem: str - Paths of the two files containing positions of Earth and Sun, - respectively at evenly spaced times, as passed to CreateFstatInput. - If None defaults defined in BaseSearchClass will be used. - - """ - - if earth_ephem is None: - self.earth_ephem = self.earth_ephem_default - if sun_ephem is None: - self.sun_ephem = self.sun_ephem_default - - self.init_computefstatistic_single_point() - - def get_SFTCatalog(self): - if hasattr(self, 'SFTCatalog'): - return - logging.info('Initialising SFTCatalog') - constraints = lalpulsar.SFTConstraints() - if self.detector: - constraints.detector = self.detector - if self.minStartTime: - constraints.minStartTime = lal.LIGOTimeGPS(self.minStartTime) - if self.maxStartTime: - constraints.maxStartTime = lal.LIGOTimeGPS(self.maxStartTime) - - logging.info('Loading data matching pattern {}'.format( - self.sftfilepath)) - SFTCatalog = lalpulsar.SFTdataFind(self.sftfilepath, constraints) - detector_names = list(set([d.header.name for d in SFTCatalog.data])) - self.detector_names = detector_names - SFT_timestamps = [d.header.epoch for d in SFTCatalog.data] - if args.quite is False and args.no_interactive is False: - try: - from bashplotlib.histogram import plot_hist - print('Data timestamps histogram:') - plot_hist(SFT_timestamps, height=5, bincount=50) - except IOError: - pass - if len(detector_names) == 0: - raise ValueError('No data loaded.') - logging.info('Loaded {} data files from detectors {}'.format( - len(SFT_timestamps), detector_names)) - logging.info('Data spans from {} ({}) to {} ({})'.format( - int(SFT_timestamps[0]), - subprocess.check_output('lalapps_tconvert {}'.format( - int(SFT_timestamps[0])), shell=True).rstrip('\n'), - int(SFT_timestamps[-1]), - subprocess.check_output('lalapps_tconvert {}'.format( - int(SFT_timestamps[-1])), shell=True).rstrip('\n'))) - self.SFTCatalog = SFTCatalog - - def init_computefstatistic_single_point(self): - """ Initilisation step of run_computefstatistic for a single point """ - - self.get_SFTCatalog() - - logging.info('Initialising ephems') - ephems = lalpulsar.InitBarycenter(self.earth_ephem, self.sun_ephem) - - logging.info('Initialising FstatInput') - dFreq = 0 - if self.transient: - self.whatToCompute = lalpulsar.FSTATQ_ATOMS_PER_DET - else: - self.whatToCompute = lalpulsar.FSTATQ_2F - - FstatOAs = lalpulsar.FstatOptionalArgs() - FstatOAs.randSeed = lalpulsar.FstatOptionalArgsDefaults.randSeed - FstatOAs.SSBprec = lalpulsar.FstatOptionalArgsDefaults.SSBprec - FstatOAs.Dterms = lalpulsar.FstatOptionalArgsDefaults.Dterms - FstatOAs.runningMedianWindow = lalpulsar.FstatOptionalArgsDefaults.runningMedianWindow - FstatOAs.FstatMethod = lalpulsar.FstatOptionalArgsDefaults.FstatMethod - FstatOAs.InjectSqrtSX = lalpulsar.FstatOptionalArgsDefaults.injectSqrtSX - FstatOAs.assumeSqrtSX = lalpulsar.FstatOptionalArgsDefaults.assumeSqrtSX - FstatOAs.prevInput = lalpulsar.FstatOptionalArgsDefaults.prevInput - FstatOAs.collectTiming = lalpulsar.FstatOptionalArgsDefaults.collectTiming - - if hasattr(self, 'injectSource') and type(self.injectSources) == dict: - logging.info('Injecting source with params: {}'.format( - self.injectSources)) - PPV = lalpulsar.CreatePulsarParamsVector(1) - PP = PPV.data[0] - PP.Amp.h0 = self.injectSources['h0'] - PP.Amp.cosi = self.injectSources['cosi'] - PP.Amp.phi0 = self.injectSources['phi0'] - PP.Amp.psi = self.injectSources['psi'] - PP.Doppler.Alpha = self.injectSources['Alpha'] - PP.Doppler.Delta = self.injectSources['Delta'] - PP.Doppler.fkdot = np.array(self.injectSources['fkdot']) - PP.Doppler.refTime = self.tref - if 't0' not in self.injectSources: - PP.Transient.type = lalpulsar.TRANSIENT_NONE - FstatOAs.injectSources = PPV - else: - FstatOAs.injectSources = lalpulsar.FstatOptionalArgsDefaults.injectSources - - if self.minCoverFreq is None or self.maxCoverFreq is None: - fAs = [d.header.f0 for d in self.SFTCatalog.data] - fBs = [d.header.f0 + (d.numBins-1)*d.header.deltaF - for d in self.SFTCatalog.data] - self.minCoverFreq = np.min(fAs) + 0.5 - self.maxCoverFreq = np.max(fBs) - 0.5 - logging.info('Min/max cover freqs not provided, using ' - '{} and {}, est. from SFTs'.format( - self.minCoverFreq, self.maxCoverFreq)) - - self.FstatInput = lalpulsar.CreateFstatInput(self.SFTCatalog, - self.minCoverFreq, - self.maxCoverFreq, - dFreq, - ephems, - FstatOAs - ) - - logging.info('Initialising PulsarDoplerParams') - PulsarDopplerParams = lalpulsar.PulsarDopplerParams() - PulsarDopplerParams.refTime = self.tref - PulsarDopplerParams.Alpha = 1 - PulsarDopplerParams.Delta = 1 - PulsarDopplerParams.fkdot = np.array([0, 0, 0, 0, 0, 0, 0]) - self.PulsarDopplerParams = PulsarDopplerParams - - logging.info('Initialising FstatResults') - self.FstatResults = lalpulsar.FstatResults() - - if self.BSGL: - if len(self.detector_names) < 2: - raise ValueError("Can't use BSGL with single detector data") - else: - logging.info('Initialising BSGL') - - # Tuning parameters - to be reviewed - numDetectors = 2 - if hasattr(self, 'nsegs'): - p_val_threshold = 1e-6 - Fstar0s = np.linspace(0, 1000, 10000) - p_vals = scipy.special.gammaincc(2*self.nsegs, Fstar0s) - Fstar0 = Fstar0s[np.argmin(np.abs(p_vals - p_val_threshold))] - if Fstar0 == Fstar0s[-1]: - raise ValueError('Max Fstar0 exceeded') - else: - Fstar0 = 15. - logging.info('Using Fstar0 of {:1.2f}'.format(Fstar0)) - oLGX = np.zeros(10) - oLGX[:numDetectors] = 1./numDetectors - self.BSGLSetup = lalpulsar.CreateBSGLSetup(numDetectors, - Fstar0, - oLGX, - True, - 1) - self.twoFX = np.zeros(10) - self.whatToCompute = (self.whatToCompute + - lalpulsar.FSTATQ_2F_PER_DET) - - if self.transient: - logging.info('Initialising transient parameters') - self.windowRange = lalpulsar.transientWindowRange_t() - self.windowRange.type = lalpulsar.TRANSIENT_RECTANGULAR - self.windowRange.t0Band = 0 - self.windowRange.dt0 = 1 - self.windowRange.tauBand = 0 - self.windowRange.dtau = 1 - - def compute_fullycoherent_det_stat_single_point( - self, F0, F1, F2, Alpha, Delta, asini=None, period=None, ecc=None, - tp=None, argp=None): - """ Compute the fully-coherent det. statistic at a single point """ - - return self.run_computefstatistic_single_point( - self.minStartTime, self.maxStartTime, F0, F1, F2, Alpha, Delta, - asini, period, ecc, tp, argp) - - def run_computefstatistic_single_point(self, tstart, tend, F0, F1, - F2, Alpha, Delta, asini=None, - period=None, ecc=None, tp=None, - argp=None): - """ Returns twoF or ln(BSGL) fully-coherently at a single point """ - - self.PulsarDopplerParams.fkdot = np.array([F0, F1, F2, 0, 0, 0, 0]) - self.PulsarDopplerParams.Alpha = Alpha - self.PulsarDopplerParams.Delta = Delta - if self.binary: - self.PulsarDopplerParams.asini = asini - self.PulsarDopplerParams.period = period - self.PulsarDopplerParams.ecc = ecc - self.PulsarDopplerParams.tp = tp - self.PulsarDopplerParams.argp = argp - - lalpulsar.ComputeFstat(self.FstatResults, - self.FstatInput, - self.PulsarDopplerParams, - 1, - self.whatToCompute - ) - - if self.transient is False: - if self.BSGL is False: - return self.FstatResults.twoF[0] - - twoF = np.float(self.FstatResults.twoF[0]) - self.twoFX[0] = self.FstatResults.twoFPerDet(0) - self.twoFX[1] = self.FstatResults.twoFPerDet(1) - log10_BSGL = lalpulsar.ComputeBSGL(twoF, self.twoFX, - self.BSGLSetup) - return log10_BSGL/np.log10(np.exp(1)) - - self.windowRange.t0 = int(tstart) # TYPE UINT4 - self.windowRange.tau = int(tend - tstart) # TYPE UINT4 - - FS = lalpulsar.ComputeTransientFstatMap( - self.FstatResults.multiFatoms[0], self.windowRange, False) - - if self.BSGL is False: - return 2*FS.F_mn.data[0][0] - - FstatResults_single = copy.copy(self.FstatResults) - FstatResults_single.lenth = 1 - FstatResults_single.data = self.FstatResults.multiFatoms[0].data[0] - FS0 = lalpulsar.ComputeTransientFstatMap( - FstatResults_single.multiFatoms[0], self.windowRange, False) - FstatResults_single.data = self.FstatResults.multiFatoms[0].data[1] - FS1 = lalpulsar.ComputeTransientFstatMap( - FstatResults_single.multiFatoms[0], self.windowRange, False) - - self.twoFX[0] = 2*FS0.F_mn.data[0][0] - self.twoFX[1] = 2*FS1.F_mn.data[0][0] - log10_BSGL = lalpulsar.ComputeBSGL( - 2*FS.F_mn.data[0][0], self.twoFX, self.BSGLSetup) - - return log10_BSGL/np.log10(np.exp(1)) - - def calculate_twoF_cumulative(self, F0, F1, F2, Alpha, Delta, asini=None, - period=None, ecc=None, tp=None, argp=None, - tstart=None, tend=None, npoints=1000, - minfraction=0.01, maxfraction=1): - """ Calculate the cumulative twoF along the obseration span """ - duration = tend - tstart - tstart = tstart + minfraction*duration - taus = np.linspace(minfraction*duration, maxfraction*duration, npoints) - twoFs = [] - if self.transient is False: - self.transient = True - self.init_computefstatistic_single_point() - for tau in taus: - twoFs.append(self.run_computefstatistic_single_point( - tstart=tstart, tend=tstart+tau, F0=F0, F1=F1, F2=F2, - Alpha=Alpha, Delta=Delta, asini=asini, period=period, ecc=ecc, - tp=tp, argp=argp)) - - return taus, np.array(twoFs) - - def plot_twoF_cumulative(self, label, outdir, ax=None, c='k', savefig=True, - title=None, **kwargs): - - taus, twoFs = self.calculate_twoF_cumulative(**kwargs) - if ax is None: - fig, ax = plt.subplots() - ax.plot(taus/86400., twoFs, label=label, color=c) - ax.set_xlabel(r'Days from $t_{{\rm start}}={:.0f}$'.format( - kwargs['tstart'])) - if self.BSGL: - ax.set_ylabel(r'$\log_{10}(\mathrm{BSGL})_{\rm cumulative}$') - else: - ax.set_ylabel(r'$\widetilde{2\mathcal{F}}_{\rm cumulative}$') - ax.set_xlim(0, taus[-1]/86400) - if title: - ax.set_title(title) - if savefig: - plt.tight_layout() - plt.savefig('{}/{}_twoFcumulative.png'.format(outdir, label)) - return taus, twoFs - else: - return ax - - -class SemiCoherentSearch(BaseSearchClass, ComputeFstat): - """ A semi-coherent search """ - - @initializer - def __init__(self, label, outdir, tref, nsegs=None, sftfilepath=None, - binary=False, BSGL=False, minStartTime=None, - maxStartTime=None, minCoverFreq=None, maxCoverFreq=None, - detector=None, earth_ephem=None, sun_ephem=None, - injectSources=None): - """ - Parameters - ---------- - label, outdir: str - A label and directory to read/write data from/to. - tref, minStartTime, maxStartTime: int - GPS seconds of the reference time, and start and end of the data. - nsegs: int - The (fixed) number of segments - sftfilepath: str - File patern to match SFTs - - For all other parameters, see pyfstat.ComputeFStat. - """ - - self.fs_file_name = "{}/{}_FS.dat".format(self.outdir, self.label) - if self.earth_ephem is None: - self.earth_ephem = self.earth_ephem_default - if self.sun_ephem is None: - self.sun_ephem = self.sun_ephem_default - self.transient = True - self.init_computefstatistic_single_point() - self.init_semicoherent_parameters() - - def init_semicoherent_parameters(self): - logging.info(('Initialising semicoherent parameters from {} to {} in' - ' {} segments').format( - self.minStartTime, self.maxStartTime, self.nsegs)) - self.transient = True - self.whatToCompute = lalpulsar.FSTATQ_2F+lalpulsar.FSTATQ_ATOMS_PER_DET - self.tboundaries = np.linspace(self.minStartTime, self.maxStartTime, - self.nsegs+1) - - def run_semi_coherent_computefstatistic_single_point( - self, F0, F1, F2, Alpha, Delta, asini=None, - period=None, ecc=None, tp=None, argp=None): - """ Returns twoF or ln(BSGL) semi-coherently at a single point """ - - self.PulsarDopplerParams.fkdot = np.array([F0, F1, F2, 0, 0, 0, 0]) - self.PulsarDopplerParams.Alpha = Alpha - self.PulsarDopplerParams.Delta = Delta - if self.binary: - self.PulsarDopplerParams.asini = asini - self.PulsarDopplerParams.period = period - self.PulsarDopplerParams.ecc = ecc - self.PulsarDopplerParams.tp = tp - self.PulsarDopplerParams.argp = argp - - lalpulsar.ComputeFstat(self.FstatResults, - self.FstatInput, - self.PulsarDopplerParams, - 1, - self.whatToCompute - ) - - if self.transient is False: - if self.BSGL is False: - return self.FstatResults.twoF[0] - - twoF = np.float(self.FstatResults.twoF[0]) - self.twoFX[0] = self.FstatResults.twoFPerDet(0) - self.twoFX[1] = self.FstatResults.twoFPerDet(1) - log10_BSGL = lalpulsar.ComputeBSGL(twoF, self.twoFX, - self.BSGLSetup) - return log10_BSGL/np.log10(np.exp(1)) - - detStat = 0 - for tstart, tend in zip(self.tboundaries[:-1], self.tboundaries[1:]): - self.windowRange.t0 = int(tstart) # TYPE UINT4 - self.windowRange.tau = int(tend - tstart) # TYPE UINT4 - - FS = lalpulsar.ComputeTransientFstatMap( - self.FstatResults.multiFatoms[0], self.windowRange, False) - - if self.BSGL is False: - detStat += 2*FS.F_mn.data[0][0] - continue - - FstatResults_single = copy.copy(self.FstatResults) - FstatResults_single.lenth = 1 - FstatResults_single.data = self.FstatResults.multiFatoms[0].data[0] - FS0 = lalpulsar.ComputeTransientFstatMap( - FstatResults_single.multiFatoms[0], self.windowRange, False) - FstatResults_single.data = self.FstatResults.multiFatoms[0].data[1] - FS1 = lalpulsar.ComputeTransientFstatMap( - FstatResults_single.multiFatoms[0], self.windowRange, False) - - self.twoFX[0] = 2*FS0.F_mn.data[0][0] - self.twoFX[1] = 2*FS1.F_mn.data[0][0] - log10_BSGL = lalpulsar.ComputeBSGL( - 2*FS.F_mn.data[0][0], self.twoFX, self.BSGLSetup) - - detStat += log10_BSGL/np.log10(np.exp(1)) - - return detStat - - -class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat): - """ A semi-coherent glitch search - - This implements a basic `semi-coherent glitch F-stat in which the data - is divided into segments either side of the proposed glitches and the - fully-coherent F-stat in each segment is summed to give the semi-coherent - F-stat - """ - - @initializer - def __init__(self, label, outdir, tref, minStartTime, maxStartTime, - nglitch=0, sftfilepath=None, theta0_idx=0, BSGL=False, - minCoverFreq=None, maxCoverFreq=None, - detector=None, earth_ephem=None, sun_ephem=None): - """ - Parameters - ---------- - label, outdir: str - A label and directory to read/write data from/to. - tref, minStartTime, maxStartTime: int - GPS seconds of the reference time, and start and end of the data. - nglitch: int - The (fixed) number of glitches; this can zero, but occasionally - this causes issue (in which case just use ComputeFstat). - sftfilepath: str - File patern to match SFTs - theta0_idx, int - Index (zero-based) of which segment the theta refers to - uyseful - if providing a tight prior on theta to allow the signal to jump - too theta (and not just from) - - For all other parameters, see pyfstat.ComputeFStat. - """ - - self.fs_file_name = "{}/{}_FS.dat".format(self.outdir, self.label) - if self.earth_ephem is None: - self.earth_ephem = self.earth_ephem_default - if self.sun_ephem is None: - self.sun_ephem = self.sun_ephem_default - self.transient = True - self.binary = False - self.init_computefstatistic_single_point() - - def compute_nglitch_fstat(self, F0, F1, F2, Alpha, Delta, *args): - """ Returns the semi-coherent glitch summed twoF """ - - args = list(args) - tboundaries = ([self.minStartTime] + args[-self.nglitch:] - + [self.maxStartTime]) - delta_F0s = args[-3*self.nglitch:-2*self.nglitch] - delta_F1s = args[-2*self.nglitch:-self.nglitch] - delta_F2 = np.zeros(len(delta_F0s)) - delta_phi = np.zeros(len(delta_F0s)) - theta = [0, F0, F1, F2] - delta_thetas = np.atleast_2d( - np.array([delta_phi, delta_F0s, delta_F1s, delta_F2]).T) - - thetas = self.calculate_thetas(theta, delta_thetas, tboundaries, - theta0_idx=self.theta0_idx) - - twoFSum = 0 - for i, theta_i_at_tref in enumerate(thetas): - ts, te = tboundaries[i], tboundaries[i+1] - - twoFVal = self.run_computefstatistic_single_point( - ts, te, theta_i_at_tref[1], theta_i_at_tref[2], - theta_i_at_tref[3], Alpha, Delta) - twoFSum += twoFVal - - if np.isfinite(twoFSum): - return twoFSum - else: - return -np.inf - - def compute_glitch_fstat_single(self, F0, F1, F2, Alpha, Delta, delta_F0, - delta_F1, tglitch): - """ Returns the semi-coherent glitch summed twoF for nglitch=1 - - Note: OBSOLETE, used only for testing - """ - - theta = [F0, F1, F2] - delta_theta = [delta_F0, delta_F1, 0] - tref = self.tref - - theta_at_glitch = self.shift_coefficients(theta, tglitch - tref) - theta_post_glitch_at_glitch = theta_at_glitch + delta_theta - theta_post_glitch = self.shift_coefficients( - theta_post_glitch_at_glitch, tref - tglitch) - twoFsegA = self.run_computefstatistic_single_point( - self.minStartTime, tglitch, theta[0], theta[1], theta[2], Alpha, - Delta) - - if tglitch == self.maxStartTime: - return twoFsegA - - twoFsegB = self.run_computefstatistic_single_point( - tglitch, self.maxStartTime, theta_post_glitch[0], - theta_post_glitch[1], theta_post_glitch[2], Alpha, - Delta) - - return twoFsegA + twoFsegB +from core import BaseSearchClass, ComputeFstat +from core import tqdm, args, earth_ephem, sun_ephem +from optimal_setup_functions import get_optimal_setup +import helper_functions class MCMCSearch(BaseSearchClass): """ MCMC search using ComputeFstat""" - @initializer + @helper_functions.initializer def __init__(self, label, outdir, sftfilepath, theta_prior, tref, minStartTime, maxStartTime, nsteps=[100, 100], nwalkers=100, ntemps=1, log10temperature_min=-5, @@ -1426,6 +545,7 @@ class MCMCSearch(BaseSearchClass): raise ValueError('subtractions must be of length ndim') with plt.style.context((context)): + plt.rcParams['text.usetex'] = True if fig is None and axes is None: fig = plt.figure(figsize=(4, 3.0*ndim)) ax = fig.add_subplot(ndim+1, 1, 1) @@ -1464,7 +584,6 @@ class MCMCSearch(BaseSearchClass): if symbols: axes[0].set_ylabel(symbols[0], labelpad=labelpad) - if plot_det_stat: if len(axes) == ndim: axes.append(fig.add_subplot(ndim+1, 1, ndim+1)) @@ -1613,14 +732,6 @@ class MCMCSearch(BaseSearchClass): with open(self.pickle_path, "wb") as File: pickle.dump(d, File) - def get_list_of_matching_sfts(self): - matches = glob.glob(self.sftfilepath) - if len(matches) > 0: - return matches - else: - raise IOError('No sfts found matching {}'.format( - self.sftfilepath)) - def get_saved_data(self): with open(self.pickle_path, "r") as File: d = pickle.load(File) @@ -1878,7 +989,7 @@ class MCMCSearch(BaseSearchClass): class MCMCGlitchSearch(MCMCSearch): """ MCMC search using the SemiCoherentGlitchSearch """ - @initializer + @helper_functions.initializer def __init__(self, label, outdir, sftfilepath, theta_prior, tref, minStartTime, maxStartTime, nglitch=1, nsteps=[100, 100], nwalkers=100, ntemps=1, log10temperature_min=-5, @@ -2129,7 +1240,7 @@ class MCMCGlitchSearch(MCMCSearch): class MCMCSemiCoherentSearch(MCMCSearch): """ MCMC search for a signal using the semi-coherent ComputeFstat """ - @initializer + @helper_functions.initializer def __init__(self, label, outdir, sftfilepath, theta_prior, tref, nsegs=None, nsteps=[100, 100, 100], nwalkers=100, binary=False, ntemps=1, log10temperature_min=-5, theta_initial=None, @@ -2367,9 +1478,10 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch): else: nsteps = '{},{}'.format(*rs[0]) line = line.format(i, rs[1], '{:1.1f}'.format(Tcoh), - nsteps, texify_float(V), - texify_float(Vsky), - texify_float(Vpe)) + nsteps, + helper_functions.texify_float(V), + helper_functions.texify_float(Vsky), + helper_functions.texify_float(Vpe)) f.write(line) f.write(r'\end{tabular}' + '\n') else: @@ -2391,7 +1503,8 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch): else: nsteps = '{},{}'.format(*rs[0]) line = line.format(i, rs[1], '{:1.1f}'.format(Tcoh), - nsteps, texify_float(Vpe)) + nsteps, + helper_functions.texify_float(Vpe)) f.write(line) f.write(r'\end{tabular}' + '\n') @@ -2565,524 +1678,4 @@ class MCMCTransientSearch(MCMCSearch): self.theta_keys = [self.theta_keys[i] for i in idxs] -class GridSearch(BaseSearchClass): - """ Gridded search using ComputeFstat """ - @initializer - def __init__(self, label, outdir, sftfilepath, F0s=[0], F1s=[0], F2s=[0], - Alphas=[0], Deltas=[0], tref=None, minStartTime=None, - maxStartTime=None, BSGL=False, minCoverFreq=None, - maxCoverFreq=None, earth_ephem=None, sun_ephem=None, - detector=None): - """ - Parameters - ---------- - label, outdir: str - A label and directory to read/write data from/to - sftfilepath: str - File patern to match SFTs - F0s, F1s, F2s, delta_F0s, delta_F1s, tglitchs, Alphas, Deltas: tuple - Length 3 tuple describing the grid for each parameter, e.g - [F0min, F0max, dF0], for a fixed value simply give [F0]. - tref, minStartTime, maxStartTime: int - GPS seconds of the reference time, start time and end time - - For all other parameters, see `pyfstat.ComputeFStat` for details - """ - - if earth_ephem is None: - self.earth_ephem = self.earth_ephem_default - if sun_ephem is None: - self.sun_ephem = self.sun_ephem_default - - if os.path.isdir(outdir) is False: - os.mkdir(outdir) - self.out_file = '{}/{}_gridFS.txt'.format(self.outdir, self.label) - self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta'] - - def inititate_search_object(self): - logging.info('Setting up search object') - self.search = ComputeFstat( - tref=self.tref, sftfilepath=self.sftfilepath, - minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq, - earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem, - detector=self.detector, transient=False, - minStartTime=self.minStartTime, maxStartTime=self.maxStartTime, - BSGL=self.BSGL) - - def get_array_from_tuple(self, x): - if len(x) == 1: - return np.array(x) - else: - return np.arange(x[0], x[1]*(1+1e-15), x[2]) - - def get_input_data_array(self): - arrays = [] - for tup in ([self.minStartTime], [self.maxStartTime], self.F0s, self.F1s, self.F2s, - self.Alphas, self.Deltas): - arrays.append(self.get_array_from_tuple(tup)) - - input_data = [] - for vals in itertools.product(*arrays): - input_data.append(vals) - - self.arrays = arrays - self.input_data = np.array(input_data) - - def check_old_data_is_okay_to_use(self): - if args.clean: - return False - if os.path.isfile(self.out_file) is False: - logging.info('No old data found, continuing with grid search') - return False - data = np.atleast_2d(np.genfromtxt(self.out_file, delimiter=' ')) - if np.all(data[:, 0:-1] == self.input_data): - logging.info( - 'Old data found with matching input, no search performed') - return data - else: - logging.info( - 'Old data found, input differs, continuing with grid search') - return False - - def run(self, return_data=False): - self.get_input_data_array() - old_data = self.check_old_data_is_okay_to_use() - if old_data is not False: - self.data = old_data - return - - self.inititate_search_object() - - logging.info('Total number of grid points is {}'.format( - len(self.input_data))) - - data = [] - for vals in tqdm(self.input_data): - FS = self.search.run_computefstatistic_single_point(*vals) - data.append(list(vals) + [FS]) - - data = np.array(data) - if return_data: - return data - else: - logging.info('Saving data to {}'.format(self.out_file)) - np.savetxt(self.out_file, data, delimiter=' ') - self.data = data - - def convert_F0_to_mismatch(self, F0, F0hat, Tseg): - DeltaF0 = F0[1] - F0[0] - m_spacing = (np.pi*Tseg*DeltaF0)**2 / 12. - N = len(F0) - return np.arange(-N*m_spacing/2., N*m_spacing/2., m_spacing) - - def convert_F1_to_mismatch(self, F1, F1hat, Tseg): - DeltaF1 = F1[1] - F1[0] - m_spacing = (np.pi*Tseg**2*DeltaF1)**2 / 720. - N = len(F1) - return np.arange(-N*m_spacing/2., N*m_spacing/2., m_spacing) - - def add_mismatch_to_ax(self, ax, x, y, xkey, ykey, xhat, yhat, Tseg): - axX = ax.twiny() - axX.zorder = -10 - axY = ax.twinx() - axY.zorder = -10 - - if xkey == 'F0': - m = self.convert_F0_to_mismatch(x, xhat, Tseg) - axX.set_xlim(m[0], m[-1]) - - if ykey == 'F1': - m = self.convert_F1_to_mismatch(y, yhat, Tseg) - axY.set_ylim(m[0], m[-1]) - - def plot_1D(self, xkey): - fig, ax = plt.subplots() - xidx = self.keys.index(xkey) - x = np.unique(self.data[:, xidx]) - z = self.data[:, -1] - plt.plot(x, z) - fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label)) - - def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None, - add_mismatch=None, xN=None, yN=None, flat_keys=[], - rel_flat_idxs=[], flatten_method=np.max, - predicted_twoF=None, cm=None, cbarkwargs={}): - """ Plots a 2D grid of 2F values - - Parameters - ---------- - add_mismatch: tuple (xhat, yhat, Tseg) - If not None, add a secondary axis with the metric mismatch from the - point xhat, yhat with duration Tseg - flatten_method: np.max - Function to use in flattening the flat_keys - """ - if ax is None: - fig, ax = plt.subplots() - xidx = self.keys.index(xkey) - yidx = self.keys.index(ykey) - flat_idxs = [self.keys.index(k) for k in flat_keys] - - x = np.unique(self.data[:, xidx]) - y = np.unique(self.data[:, yidx]) - flat_vals = [np.unique(self.data[:, j]) for j in flat_idxs] - z = self.data[:, -1] - - Y, X = np.meshgrid(y, x) - shape = [len(x), len(y)] + [len(v) for v in flat_vals] - Z = z.reshape(shape) - - if len(rel_flat_idxs) > 0: - Z = flatten_method(Z, axis=tuple(rel_flat_idxs)) - - if predicted_twoF: - Z = (predicted_twoF - Z) / (predicted_twoF + 4) - if cm is None: - cm = plt.cm.viridis_r - else: - if cm is None: - cm = plt.cm.viridis - - pax = ax.pcolormesh(X, Y, Z, cmap=cm, vmin=vmin, vmax=vmax) - cb = plt.colorbar(pax, ax=ax, **cbarkwargs) - cb.set_label('$2\mathcal{F}$') - - if add_mismatch: - self.add_mismatch_to_ax(ax, x, y, xkey, ykey, *add_mismatch) - - ax.set_xlim(x[0], x[-1]) - ax.set_ylim(y[0], y[-1]) - labels = {'F0': '$f$', 'F1': '$\dot{f}$'} - ax.set_xlabel(labels[xkey]) - ax.set_ylabel(labels[ykey]) - - if xN: - ax.xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(xN)) - if yN: - ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(yN)) - - if save: - fig.tight_layout() - fig.savefig('{}/{}_2D.png'.format(self.outdir, self.label)) - else: - return ax - - def get_max_twoF(self): - twoF = self.data[:, -1] - idx = np.argmax(twoF) - v = self.data[idx, :] - d = OrderedDict(minStartTime=v[0], maxStartTime=v[1], F0=v[2], F1=v[3], - F2=v[4], Alpha=v[5], Delta=v[6], twoF=v[7]) - return d - - def print_max_twoF(self): - d = self.get_max_twoF() - print('Max twoF values for {}:'.format(self.label)) - for k, v in d.iteritems(): - print(' {}={}'.format(k, v)) - - -class GridGlitchSearch(GridSearch): - """ Grid search using the SemiCoherentGlitchSearch """ - @initializer - def __init__(self, label, outdir, sftfilepath=None, F0s=[0], - F1s=[0], F2s=[0], delta_F0s=[0], delta_F1s=[0], tglitchs=None, - Alphas=[0], Deltas=[0], tref=None, minStartTime=None, - maxStartTime=None, minCoverFreq=None, maxCoverFreq=None, - write_after=1000, earth_ephem=None, sun_ephem=None): - - """ - Parameters - ---------- - label, outdir: str - A label and directory to read/write data from/to - sftfilepath: str - File patern to match SFTs - F0s, F1s, F2s, delta_F0s, delta_F1s, tglitchs, Alphas, Deltas: tuple - Length 3 tuple describing the grid for each parameter, e.g - [F0min, F0max, dF0], for a fixed value simply give [F0]. - tref, minStartTime, maxStartTime: int - GPS seconds of the reference time, start time and end time - - For all other parameters, see pyfstat.ComputeFStat. - """ - if tglitchs is None: - self.tglitchs = [self.maxStartTime] - if earth_ephem is None: - self.earth_ephem = self.earth_ephem_default - if sun_ephem is None: - self.sun_ephem = self.sun_ephem_default - - self.search = SemiCoherentGlitchSearch( - label=label, outdir=outdir, sftfilepath=self.sftfilepath, - tref=tref, minStartTime=minStartTime, maxStartTime=maxStartTime, - minCoverFreq=minCoverFreq, maxCoverFreq=maxCoverFreq, - earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem, - BSGL=self.BSGL) - - if os.path.isdir(outdir) is False: - os.mkdir(outdir) - self.out_file = '{}/{}_gridFS.txt'.format(self.outdir, self.label) - self.keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta', 'delta_F0', - 'delta_F1', 'tglitch'] - - def get_input_data_array(self): - arrays = [] - for tup in (self.F0s, self.F1s, self.F2s, self.Alphas, self.Deltas, - self.delta_F0s, self.delta_F1s, self.tglitchs): - arrays.append(self.get_array_from_tuple(tup)) - - input_data = [] - for vals in itertools.product(*arrays): - input_data.append(vals) - - self.arrays = arrays - self.input_data = np.array(input_data) - - -class Writer(BaseSearchClass): - """ Instance object for generating SFTs containing glitch signals """ - @initializer - def __init__(self, label='Test', tstart=700000000, duration=100*86400, - dtglitch=None, delta_phi=0, delta_F0=0, delta_F1=0, - delta_F2=0, tref=None, F0=30, F1=1e-10, F2=0, Alpha=5e-3, - Delta=6e-2, h0=0.1, cosi=0.0, psi=0.0, phi=0, Tsft=1800, - outdir=".", sqrtSX=1, Band=4, detector='H1', - minStartTime=None, maxStartTime=None): - """ - Parameters - ---------- - label: string - a human-readable label to be used in naming the output files - tstart, tend : float - start and end times (in gps seconds) of the total observation span - dtglitch: float - time (in gps seconds) of the glitch after tstart. To create data - without a glitch, set dtglitch=tend-tstart or leave as None - delta_phi, delta_F0, delta_F1: float - instanteneous glitch magnitudes in rad, Hz, and Hz/s respectively - tref: float or None - reference time (default is None, which sets the reference time to - tstart) - F0, F1, F2, Alpha, Delta, h0, cosi, psi, phi: float - frequency, sky-position, and amplitude parameters - Tsft: float - the sft duration - minStartTime, maxStartTime: float - if not None, the total span of data, this can be used to generate - transient signals - - see `lalapps_Makefakedata_v5 --help` for help with the other paramaters - """ - - for d in self.delta_phi, self.delta_F0, self.delta_F1, self.delta_F2: - if np.size(d) == 1: - d = [d] - self.tend = self.tstart + self.duration - if self.minStartTime is None: - self.minStartTime = self.tstart - if self.maxStartTime is None: - self.maxStartTime = self.tend - if self.dtglitch is None or self.dtglitch == self.duration: - self.tbounds = [self.tstart, self.tend] - elif np.size(self.dtglitch) == 1: - self.tbounds = [self.tstart, self.tstart+self.dtglitch, self.tend] - else: - self.tglitch = self.tstart + np.array(self.dtglitch) - self.tbounds = [self.tstart] + list(self.tglitch) + [self.tend] - - if os.path.isdir(self.outdir) is False: - os.makedirs(self.outdir) - if self.tref is None: - self.tref = self.tstart - self.tend = self.tstart + self.duration - tbs = np.array(self.tbounds) - self.durations_days = (tbs[1:] - tbs[:-1]) / 86400 - self.config_file_name = "{}/{}.cff".format(outdir, label) - - self.theta = np.array([phi, F0, F1, F2]) - self.delta_thetas = np.atleast_2d( - np.array([delta_phi, delta_F0, delta_F1, delta_F2]).T) - - self.data_duration = self.maxStartTime - self.minStartTime - numSFTs = int(float(self.data_duration) / self.Tsft) - self.sftfilename = lalpulsar.OfficialSFTFilename( - 'H', '1', numSFTs, self.Tsft, self.minStartTime, - self.data_duration, self.label) - self.sftfilepath = '{}/{}'.format(self.outdir, self.sftfilename) - self.calculate_fmin_Band() - - def make_data(self): - ''' A convienience wrapper to generate a cff file then sfts ''' - self.make_cff() - self.run_makefakedata() - - def get_single_config_line(self, i, Alpha, Delta, h0, cosi, psi, phi, F0, - F1, F2, tref, tstart, duration_days): - template = ( -"""[TS{}] -Alpha = {:1.18e} -Delta = {:1.18e} -h0 = {:1.18e} -cosi = {:1.18e} -psi = {:1.18e} -phi0 = {:1.18e} -Freq = {:1.18e} -f1dot = {:1.18e} -f2dot = {:1.18e} -refTime = {:10.6f} -transientWindowType=rect -transientStartTime={:10.3f} -transientTauDays={:1.3f}\n""") - return template.format(i, Alpha, Delta, h0, cosi, psi, phi, F0, F1, - F2, tref, tstart, duration_days) - - def make_cff(self): - """ - Generates an .cff file for a 'glitching' signal - - """ - - thetas = self.calculate_thetas(self.theta, self.delta_thetas, - self.tbounds) - - content = '' - for i, (t, d, ts) in enumerate(zip(thetas, self.durations_days, - self.tbounds[:-1])): - line = self.get_single_config_line( - i, self.Alpha, self.Delta, self.h0, self.cosi, self.psi, - t[0], t[1], t[2], t[3], self.tref, ts, d) - - content += line - - if self.check_if_cff_file_needs_rewritting(content): - config_file = open(self.config_file_name, "w+") - config_file.write(content) - config_file.close() - def calculate_fmin_Band(self): - self.fmin = self.F0 - .5 * self.Band - - def check_cached_data_okay_to_use(self, cl): - """ Check if cached data exists and, if it does, if it can be used """ - - getmtime = os.path.getmtime - - if os.path.isfile(self.sftfilepath) is False: - logging.info('No SFT file matching {} found'.format( - self.sftfilepath)) - return False - else: - logging.info('Matching SFT file found') - - if getmtime(self.sftfilepath) < getmtime(self.config_file_name): - logging.info( - ('The config file {} has been modified since the sft file {} ' - + 'was created').format( - self.config_file_name, self.sftfilepath)) - return False - - logging.info( - 'The config file {} is older than the sft file {}'.format( - self.config_file_name, self.sftfilepath)) - logging.info('Checking contents of cff file') - logging.info('Execute: {}'.format( - 'lalapps_SFTdumpheader {} | head -n 20'.format(self.sftfilepath))) - output = subprocess.check_output( - 'lalapps_SFTdumpheader {} | head -n 20'.format(self.sftfilepath), - shell=True) - calls = [line for line in output.split('\n') if line[:3] == 'lal'] - if calls[0] == cl: - logging.info('Contents matched, use old sft file') - return True - else: - logging.info('Contents unmatched, create new sft file') - return False - - def check_if_cff_file_needs_rewritting(self, content): - """ Check if the .cff file has changed - - Returns True if the file should be overwritten - where possible avoid - overwriting to allow cached data to be used - """ - if os.path.isfile(self.config_file_name) is False: - logging.info('No config file {} found'.format( - self.config_file_name)) - return True - else: - logging.info('Config file {} already exists'.format( - self.config_file_name)) - - with open(self.config_file_name, 'r') as f: - file_content = f.read() - if file_content == content: - logging.info( - 'File contents match, no update of {} required'.format( - self.config_file_name)) - return False - else: - logging.info( - 'File contents unmatched, updating {}'.format( - self.config_file_name)) - return True - - def run_makefakedata(self): - """ Generate the sft data from the configuration file """ - - # Remove old data: - try: - os.unlink("{}/*{}*.sft".format(self.outdir, self.label)) - except OSError: - pass - - cl = [] - cl.append('lalapps_Makefakedata_v5') - cl.append('--outSingleSFT=TRUE') - cl.append('--outSFTdir="{}"'.format(self.outdir)) - cl.append('--outLabel="{}"'.format(self.label)) - cl.append('--IFOs="{}"'.format(self.detector)) - cl.append('--sqrtSX="{}"'.format(self.sqrtSX)) - if self.minStartTime is None: - cl.append('--startTime={:10.9f}'.format(float(self.tstart))) - else: - cl.append('--startTime={:10.9f}'.format(float(self.minStartTime))) - if self.maxStartTime is None: - cl.append('--duration={}'.format(int(self.duration))) - else: - data_duration = self.maxStartTime - self.minStartTime - cl.append('--duration={}'.format(int(data_duration))) - cl.append('--fmin={}'.format(int(self.fmin))) - cl.append('--Band={}'.format(self.Band)) - cl.append('--Tsft={}'.format(self.Tsft)) - if self.h0 != 0: - cl.append('--injectionSources="{}"'.format(self.config_file_name)) - - cl = " ".join(cl) - - if self.check_cached_data_okay_to_use(cl) is False: - logging.info("Executing: " + cl) - os.system(cl) - os.system('\n') - - def predict_fstat(self): - """ Wrapper to lalapps_PredictFstat """ - c_l = [] - c_l.append("lalapps_PredictFstat") - c_l.append("--h0={}".format(self.h0)) - c_l.append("--cosi={}".format(self.cosi)) - c_l.append("--psi={}".format(self.psi)) - c_l.append("--Alpha={}".format(self.Alpha)) - c_l.append("--Delta={}".format(self.Delta)) - c_l.append("--Freq={}".format(self.F0)) - - c_l.append("--DataFiles='{}'".format( - self.outdir+"/*SFT_"+self.label+"*sft")) - c_l.append("--assumeSqrtSX={}".format(self.sqrtSX)) - - c_l.append("--minStartTime={}".format(int(self.minStartTime))) - c_l.append("--maxStartTime={}".format(int(self.maxStartTime))) - - logging.info("Executing: " + " ".join(c_l) + "\n") - output = subprocess.check_output(" ".join(c_l), shell=True) - twoF = float(output.split('\n')[-2]) - return float(twoF) diff --git a/pyfstat/optimal_setup_functions.py b/pyfstat/optimal_setup_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1c98f0d29e35c8405dbfbbe9a66a70379acd2b0e --- /dev/null +++ b/pyfstat/optimal_setup_functions.py @@ -0,0 +1,144 @@ +""" + +Provides functions to aid in calculating the optimal setup based on the metric +volume estimates. + +""" + +import logging +import numpy as np +import scipy.optimize +import lal +import lalpulsar + + +def get_optimal_setup( + R, Nsegs0, tref, minStartTime, maxStartTime, DeltaOmega, + DeltaFs, fiducial_freq, detector_names, earth_ephem, sun_ephem): + logging.info('Calculating optimal setup for R={}, Nsegs0={}'.format( + R, Nsegs0)) + + V_0 = get_V_estimate( + Nsegs0, tref, minStartTime, maxStartTime, DeltaOmega, DeltaFs, + fiducial_freq, detector_names, earth_ephem, sun_ephem) + logging.info('Stage {}, nsegs={}, V={}'.format(0, Nsegs0, V_0)) + + nsegs_vals = [Nsegs0] + V_vals = [V_0] + + i = 0 + nsegs_i = Nsegs0 + while nsegs_i > 1: + nsegs_i, V_i = get_nsegs_ip1( + nsegs_i, R, tref, minStartTime, maxStartTime, DeltaOmega, + DeltaFs, fiducial_freq, detector_names, earth_ephem, sun_ephem) + nsegs_vals.append(nsegs_i) + V_vals.append(V_i) + i += 1 + logging.info( + 'Stage {}, nsegs={}, V={}'.format(i, nsegs_i, V_i)) + + return nsegs_vals, V_vals + + +def get_nsegs_ip1( + nsegs_i, R, tref, minStartTime, maxStartTime, DeltaOmega, + DeltaFs, fiducial_freq, detector_names, earth_ephem, sun_ephem): + + log10R = np.log10(R) + log10Vi = np.log10(get_V_estimate( + nsegs_i, tref, minStartTime, maxStartTime, DeltaOmega, DeltaFs, + fiducial_freq, detector_names, earth_ephem, sun_ephem)) + + def f(nsegs_ip1): + if nsegs_ip1[0] > nsegs_i: + return 1e6 + if nsegs_ip1[0] < 0: + return 1e6 + nsegs_ip1 = int(nsegs_ip1[0]) + if nsegs_ip1 == 0: + nsegs_ip1 = 1 + Vip1 = get_V_estimate( + nsegs_ip1, tref, minStartTime, maxStartTime, DeltaOmega, + DeltaFs, fiducial_freq, detector_names, earth_ephem, sun_ephem) + if Vip1[0] is None: + return 1e6 + else: + log10Vip1 = np.log10(Vip1) + return np.abs(log10Vi[0] + log10R - log10Vip1[0]) + res = scipy.optimize.minimize(f, .5*nsegs_i, method='Powell', tol=0.1, + options={'maxiter': 10}) + nsegs_ip1 = int(res.x) + if nsegs_ip1 == 0: + nsegs_ip1 = 1 + if res.success: + return nsegs_ip1, get_V_estimate( + nsegs_ip1, tref, minStartTime, maxStartTime, DeltaOmega, DeltaFs, + fiducial_freq, detector_names, earth_ephem, sun_ephem) + else: + raise ValueError('Optimisation unsuccesful') + + +def get_V_estimate( + nsegs, tref, minStartTime, maxStartTime, DeltaOmega, DeltaFs, + fiducial_freq, detector_names, earth_ephem, sun_ephem): + """ Returns V, Vsky, Vpe estimated from the super-sky metric + + Parameters + ---------- + nsegs: int + Number of semi-coherent segments + tref: int + Reference time in GPS seconds + minStartTime, maxStartTime: int + Minimum and maximum SFT timestamps + DeltaOmega: float + Solid angle of the sky-patch + DeltaFs: array + Array of [DeltaF0, DeltaF1, ...], length determines the number of + spin-down terms. + fiducial_freq: float + Fidicual frequency + detector_names: array + Array of detectors to average over + earth_ephem, sun_ephem: st + Paths to the ephemeris files + + """ + spindowns = len(DeltaFs) - 1 + tboundaries = np.linspace(minStartTime, maxStartTime, nsegs+1) + + ref_time = lal.LIGOTimeGPS(tref) + segments = lal.SegListCreate() + for j in range(len(tboundaries)-1): + seg = lal.SegCreate(lal.LIGOTimeGPS(tboundaries[j]), + lal.LIGOTimeGPS(tboundaries[j+1]), + j) + lal.SegListAppend(segments, seg) + detNames = lal.CreateStringVector(*detector_names) + detectors = lalpulsar.MultiLALDetector() + lalpulsar.ParseMultiLALDetector(detectors, detNames) + detector_weights = None + detector_motion = (lalpulsar.DETMOTION_SPIN + + lalpulsar.DETMOTION_ORBIT) + ephemeris = lalpulsar.InitBarycenter(earth_ephem, sun_ephem) + try: + SSkyMetric = lalpulsar.ComputeSuperskyMetrics( + spindowns, ref_time, segments, fiducial_freq, detectors, + detector_weights, detector_motion, ephemeris) + except RuntimeError as e: + logging.debug('Encountered run-time error {}'.format(e)) + return None, None, None + + sqrtdetG_SKY = np.sqrt(np.linalg.det( + SSkyMetric.semi_rssky_metric.data[:2, :2])) + sqrtdetG_PE = np.sqrt(np.linalg.det( + SSkyMetric.semi_rssky_metric.data[2:, 2:])) + + Vsky = .5*sqrtdetG_SKY*DeltaOmega + Vpe = sqrtdetG_PE * np.prod(DeltaFs) + if Vsky == 0: + Vsky = 1 + if Vpe == 0: + Vpe = 1 + return (Vsky * Vpe, Vsky, Vpe) diff --git a/setup.py b/setup.py index e95747a99e9c2ed01b6b0293745f68c7fa5bdc03..a9d69f304d33b723a3894c2f4ab88f2adfc2f737 100644 --- a/setup.py +++ b/setup.py @@ -3,8 +3,8 @@ from distutils.core import setup setup(name='PyFstat', - version='0.1', + version='0.2', author='Gregory Ashton', author_email='gregory.ashton@ligo.org', - py_modules=['pyfstat'], + packages=['pyfstat'], ) diff --git a/tests.py b/tests.py index c0beb6164bbecfd50e272068a8d1d169608e9147..7df371873b01ed98ccb78c8962337769431f3a7d 100644 --- a/tests.py +++ b/tests.py @@ -208,7 +208,7 @@ class TestAuxillaryFunctions(Test): def test_get_V_estimate_sky_F0_F1(self): - out = pyfstat.get_V_estimate( + out = pyfstat.optimal_setup_functions.get_V_estimate( self.nsegs, self.tref, self.minStartTime, self.maxStartTime, self.DeltaOmega, self.DeltaFs, self.fiducial_freq, self.detector_names, self.earth_ephem, self.sun_ephem) @@ -217,7 +217,7 @@ class TestAuxillaryFunctions(Test): self.__class__.Vpe_COMPUTED_WITH_SKY = Vpe def test_get_V_estimate_F0_F1(self): - out = pyfstat.get_V_estimate( + out = pyfstat.optimal_setup_functions.get_V_estimate( self.nsegs, self.tref, self.minStartTime, self.maxStartTime, self.DeltaOmega, self.DeltaFs, self.fiducial_freq, self.detector_names, self.earth_ephem, self.sun_ephem)