diff --git a/pyfstat.py b/pyfstat.py index d1c7186c3ca207751eb14f5acca2291e31849d15..048ef559fd3ea70f3bcf9ecb9e19224d8d8fcc31 100755 --- a/pyfstat.py +++ b/pyfstat.py @@ -8,6 +8,7 @@ import copy import glob import inspect from functools import wraps +import subprocess import numpy as np import matplotlib @@ -25,7 +26,7 @@ if os.path.isfile(config_file): for line in f: k, v = line.split('=') k = k.replace(' ', '') - v = v.replace(' ', '') + v = v.replace(' ', '').replace("'", "").replace('"', '').replace('\n', '') d[k] = v earth_ephem = d['earth_ephem'] sun_ephem = d['sun_ephem'] @@ -37,6 +38,24 @@ else: plt.style.use('paper') +parser = argparse.ArgumentParser() +parser.add_argument("-q", "--quite", help="Decrease output verbosity", + action="store_true") +parser.add_argument("-c", "--clean", help="Don't use cached data", + action="store_true") +parser.add_argument('unittest_args', nargs='*') +args, unknown = parser.parse_known_args() +sys.argv[1:] = args.unittest_args + +if args.quite: + log_level = logging.WARNING +else: + log_level = logging.DEBUG + +logging.basicConfig(level=log_level, + format='%(asctime)s %(levelname)-8s: %(message)s', + datefmt='%H:%M') + def initializer(func): """ Automatically assigns the parameters to self""" @@ -65,24 +84,6 @@ def read_par(label, outdir): d[key] = np.float64(val) return d -parser = argparse.ArgumentParser() -parser.add_argument("-q", "--quite", help="Decrease output verbosity", - action="store_true") -parser.add_argument("-c", "--clean", help="Don't use cached data", - action="store_true") -parser.add_argument('unittest_args', nargs='*') -args, unknown = parser.parse_known_args() -sys.argv[1:] = args.unittest_args - -if args.quite: - log_level = logging.WARNING -else: - log_level = logging.DEBUG - -logging.basicConfig(level=log_level, - format='%(asctime)s %(levelname)-8s: %(message)s', - datefmt='%H:%M') - class BaseSearchClass(object): @@ -419,7 +420,7 @@ class SemiCoherentGlitchSearch(BaseSearchClass): @initializer def __init__(self, label, outdir, tref, tstart, tend, sftlabel=None, - nglitch=0, sftdir=None, minCoverFreq=29, maxCoverFreq=31, + nglitch=0, sftdir=None, minCoverFreq=None, maxCoverFreq=None, detector=None, earth_ephem=None, sun_ephem=None): """ Parameters @@ -432,7 +433,9 @@ class SemiCoherentGlitchSearch(BaseSearchClass): tref: int GPS seconds of the reference time minCoverFreq, maxCoverFreq: float - The min and max cover frequency passed to CreateFstatInput + The min and max cover frequency passed to CreateFstatInput, if + either is None the range of frequencies in the SFT less 1Hz is + used. detector: str Two character reference to the data to use, specify None for no contraint @@ -531,6 +534,14 @@ class SemiCoherentGlitchSearch(BaseSearchClass): dFreq = 0 self.whatToCompute = lalpulsar.FSTATQ_ATOMS_PER_DET FstatOptionalArgs = lalpulsar.FstatOptionalArgsDefaults + + if self.minCoverFreq is None or self.maxCoverFreq is None: + fA = SFTCatalog.data[0].header.f0 + numBins = SFTCatalog.data[0].numBins + fB = fA + (numBins-1)*SFTCatalog.data[0].header.deltaF + self.minCoverFreq = fA + 0.5 + self.maxCoverFreq = fB - 0.5 + self.FstatInput = lalpulsar.CreateFstatInput(SFTCatalog, self.minCoverFreq, self.maxCoverFreq, @@ -1323,3 +1334,235 @@ class GridGlitchSearch(BaseSearchClass): twoF = self.data[:, -1] return np.max(twoF) + +class Writer(BaseSearchClass): + """ Instance object for generating SFTs containing glitch signals """ + @initializer + def __init__(self, label='Test', tstart=700000000, duration=100*86400, + dtglitch=None, + delta_phi=0, delta_F0=0, delta_F1=0, delta_F2=0, + tref=None, phi=0, F0=30, F1=1e-10, F2=0, Alpha=5e-3, + Delta=6e-2, h0=0.1, cosi=0.0, psi=0.0, Tsft=1800, outdir=".", + sqrtSX=1, Band=4): + """ + Parameters + ---------- + label: string + a human-readable label to be used in naming the output files + tstart, tend : float + start and end times (in gps seconds) of the total observation span + dtglitch: float + time (in gps seconds) of the glitch after tstart. To create data + without a glitch, set dtglitch=tend-tstart or leave as None + delta_phi, delta_F0, delta_F1: float + instanteneous glitch magnitudes in rad, Hz, and Hz/s respectively + tref: float or None + reference time (default is None, which sets the reference time to + tstart) + phil, F0, F1, F2, Alpha, Delta, h0, cosi, psi: float + pre-glitch phase, frequency, sky-position, and signal properties + Tsft: float + the sft duration + + see `lalapps_Makefakedata_v5 --help` for help with the other paramaters + """ + + for d in self.delta_phi, self.delta_F0, self.delta_F1, self.delta_F2: + if np.size(d) == 1: + d = [d] + self.tend = self.tstart + self.duration + if self.dtglitch is None or self.dtglitch == self.duration: + self.tbounds = [self.tstart, self.tend] + elif np.size(self.dtglitch) == 1: + self.tbounds = [self.tstart, self.tstart+self.dtglitch, self.tend] + else: + self.tglitch = self.tstart + np.array(self.dtglitch) + self.tbounds = [self.tstart] + list(self.tglitch) + [self.tend] + + if os.path.isdir(self.outdir) is False: + os.makedirs(self.outdir) + if self.tref is None: + self.tref = self.tstart + self.tend = self.tstart + self.duration + tbs = np.array(self.tbounds) + self.durations_days = (tbs[1:] - tbs[:-1]) / 86400 + self.config_file_name = "{}/{}.cff".format(outdir, label) + + self.theta = np.array([phi, F0, F1, F2]) + self.delta_thetas = np.atleast_2d( + np.array([delta_phi, delta_F0, delta_F1, delta_F2]).T) + + self.detector = 'H1' + numSFTs = int(float(self.duration) / self.Tsft) + self.sft_filename = lalpulsar.OfficialSFTFilename( + 'H', '1', numSFTs, self.Tsft, self.tstart, self.duration, + self.label) + self.sft_filepath = '{}/{}'.format(self.outdir, self.sft_filename) + self.calculate_fmin_Band() + + def make_data(self): + ''' A convienience wrapper to generate a cff file then sfts ''' + self.make_cff() + self.run_makefakedata() + + def get_single_config_line(self, i, Alpha, Delta, h0, cosi, psi, phi, F0, + F1, F2, tref, tstart, duration_days): + template = ( +"""[TS{}] +Alpha = {:1.18e} +Delta = {:1.18e} +h0 = {:1.18e} +cosi = {:1.18e} +psi = {:1.18e} +phi0 = {:1.18e} +Freq = {:1.18e} +f1dot = {:1.18e} +f2dot = {:1.18e} +refTime = {:10.6f} +transientWindowType=rect +transientStartTime={:10.3f} +transientTauDays={:1.3f}\n""") + return template.format(i, Alpha, Delta, h0, cosi, psi, phi, F0, F1, + F2, tref, tstart, duration_days) + + def make_cff(self): + """ + Generates an .cff file for a 'glitching' signal + + """ + + thetas = self.calculate_thetas(self.theta, self.delta_thetas, + self.tbounds) + + content = '' + for i, (t, d, ts) in enumerate(zip(thetas, self.durations_days, + self.tbounds[:-1])): + line = self.get_single_config_line( + i, self.Alpha, self.Delta, self.h0, self.cosi, self.psi, + t[0], t[1], t[2], t[3], self.tref, ts, d) + + content += line + + if self.check_if_cff_file_needs_rewritting(content): + config_file = open(self.config_file_name, "w+") + config_file.write(content) + config_file.close() + + def calculate_fmin_Band(self): + self.fmin = self.F0 - .5 * self.Band + + def check_cached_data_okay_to_use(self, cl): + """ Check if cached data exists and, if it does, if it can be used """ + + getmtime = os.path.getmtime + + if os.path.isfile(self.sft_filepath) is False: + logging.info('No SFT file matching {} found'.format( + self.sft_filepath)) + return False + else: + logging.info('Matching SFT file found') + + if getmtime(self.sft_filepath) < getmtime(self.config_file_name): + logging.info( + ('The config file {} has been modified since the sft file {} ' + + 'was created').format( + self.config_file_name, self.sft_filepath)) + return False + + logging.info( + 'The config file {} is older than the sft file {}'.format( + self.config_file_name, self.sft_filepath)) + logging.info('Checking contents of cff file') + logging.info('Execute: {}'.format( + 'lalapps_SFTdumpheader {} | head -n 20'.format(self.sft_filepath))) + output = subprocess.check_output( + 'lalapps_SFTdumpheader {} | head -n 20'.format(self.sft_filepath), + shell=True) + calls = [line for line in output.split('\n') if line[:3] == 'lal'] + if calls[0] == cl: + logging.info('Contents matched, use old sft file') + return True + else: + logging.info('Contents unmatched, create new sft file') + return False + + def check_if_cff_file_needs_rewritting(self, content): + """ Check if the .cff file has changed + + Returns True if the file should be overwritten - where possible avoid + overwriting to allow cached data to be used + """ + if os.path.isfile(self.config_file_name) is False: + logging.info('No config file {} found'.format( + self.config_file_name)) + return True + else: + logging.info('Config file {} already exists'.format( + self.config_file_name)) + + with open(self.config_file_name, 'r') as f: + file_content = f.read() + if file_content == content: + logging.info( + 'File contents match, no update of {} required'.format( + self.config_file_name)) + return False + else: + logging.info( + 'File contents unmatched, updating {}'.format( + self.config_file_name)) + return True + + def run_makefakedata(self): + """ Generate the sft data from the configuration file """ + + # Remove old data: + try: + os.unlink("{}/*{}*.sft".format(self.outdir, self.label)) + except OSError: + pass + + cl = [] + cl.append('lalapps_Makefakedata_v5') + cl.append('--outSingleSFT=TRUE') + cl.append('--outSFTdir="{}"'.format(self.outdir)) + cl.append('--outLabel="{}"'.format(self.label)) + cl.append('--IFOs="{}"'.format(self.detector)) + cl.append('--sqrtSX="{}"'.format(self.sqrtSX)) + cl.append('--startTime={:10.9f}'.format(float(self.tstart))) + cl.append('--duration={}'.format(int(self.duration))) + cl.append('--fmin={}'.format(int(self.fmin))) + cl.append('--Band={}'.format(self.Band)) + cl.append('--Tsft={}'.format(self.Tsft)) + cl.append('--injectionSources="./{}"'.format(self.config_file_name)) + + cl = " ".join(cl) + + if self.check_cached_data_okay_to_use(cl) is False: + logging.info("Executing: " + cl) + os.system(cl) + os.system('\n') + + def predict_fstat(self): + """ Wrapper to lalapps_PredictFstat """ + c_l = [] + c_l.append("lalapps_PredictFstat") + c_l.append("--h0={}".format(self.h0)) + c_l.append("--cosi={}".format(self.cosi)) + c_l.append("--psi={}".format(self.psi)) + c_l.append("--Alpha={}".format(self.Alpha)) + c_l.append("--Delta={}".format(self.Delta)) + c_l.append("--Freq={}".format(self.F0)) + + c_l.append("--DataFiles='{}'".format( + self.outdir+"/*SFT_"+self.label+"*sft")) + c_l.append("--assumeSqrtSX={}".format(self.sqrtSX)) + + c_l.append("--minStartTime={}".format(self.tstart)) + c_l.append("--maxStartTime={}".format(self.tend)) + + logging.info("Executing: " + " ".join(c_l) + "\n") + output = subprocess.check_output(" ".join(c_l), shell=True) + twoF = float(output.split('\n')[-2]) + return float(twoF) diff --git a/tests/tests.py b/tests/tests.py new file mode 100644 index 0000000000000000000000000000000000000000..7f90a7d5166ddadbbebfeeb77c3c5be782d591cc --- /dev/null +++ b/tests/tests.py @@ -0,0 +1,272 @@ +import unittest +import pyfstat +import numpy as np +import os + + +class TestWriter(unittest.TestCase): + + def test_make_cff(self): + label = "Test" + Writer = pyfstat.Writer(label, outdir='TestData') + Writer.make_cff() + self.assertTrue(os.path.isfile('./TestData/Test.cff')) + + def test_run_makefakedata(self): + label = "Test" + Writer = pyfstat.Writer(label, outdir='TestData') + Writer.make_cff() + Writer.run_makefakedata() + self.assertTrue(os.path.isfile( + './TestData/H-4800_H1_1800SFT_Test-700000000-8640000.sft')) + + def test_makefakedata_usecached(self): + label = "Test" + Writer = pyfstat.Writer(label, outdir='TestData') + if os.path.isfile(Writer.sft_filepath): + os.remove(Writer.sft_filepath) + Writer.run_makefakedata() + time_first = os.path.getmtime(Writer.sft_filepath) + Writer.run_makefakedata() + time_second = os.path.getmtime(Writer.sft_filepath) + self.assertTrue(time_first == time_second) + os.system('touch {}'.format(Writer.config_file_name)) + Writer.run_makefakedata() + time_third = os.path.getmtime(Writer.sft_filepath) + self.assertFalse(time_first == time_third) + + +class TestBaseSearchClass(unittest.TestCase): + def test_shift_matrix(self): + BSC = pyfstat.BaseSearchClass() + dT = 10 + a = BSC.shift_matrix(4, dT) + b = np.array([[1, 2*np.pi*dT, 2*np.pi*dT**2/2.0, 2*np.pi*dT**3/6.0], + [0, 1, dT, dT**2/2.0], + [0, 0, 1, dT], + [0, 0, 0, 1]]) + self.assertTrue(np.array_equal(a, b)) + + def test_shift_coefficients(self): + BSC = pyfstat.BaseSearchClass() + thetaA = np.array([10., 1e2, 10., 1e2]) + dT = 100 + + # Calculate the 'long' way + thetaB = np.zeros(len(thetaA)) + thetaB[3] = thetaA[3] + thetaB[2] = thetaA[2] + thetaA[3]*dT + thetaB[1] = thetaA[1] + thetaA[2]*dT + .5*thetaA[3]*dT**2 + thetaB[0] = thetaA[0] + 2*np.pi*(thetaA[1]*dT + .5*thetaA[2]*dT**2 + + thetaA[3]*dT**3 / 6.0) + + self.assertTrue( + np.array_equal( + thetaB, BSC.shift_coefficients(thetaA, dT))) + + def test_shift_coefficients_loop(self): + BSC = pyfstat.BaseSearchClass() + thetaA = np.array([10., 1e2, 10., 1e2]) + dT = 1e1 + thetaB = BSC.shift_coefficients(thetaA, dT) + self.assertTrue( + np.allclose( + thetaA, BSC.shift_coefficients(thetaB, -dT), + rtol=1e-9, atol=1e-9)) + + +class TestFullyCoherentNarrowBandSearch(unittest.TestCase): + label = "Test" + outdir = 'TestData' + + def test_compute_fstat(self): + Writer = glitch_tools.Writer(self.label, outdir=self.outdir) + Writer.make_data() + + search = glitch_searches.FullyCoherentNarrowBandSearch( + self.label, self.outdir, tref=Writer.tref, Alpha=Writer.Alpha, + Delta=Writer.Delta, duration=Writer.duration, tstart=Writer.tstart, + Writer=Writer) + search.run_computefstatistic_slow(m=1e-3, n=0) + _, _, _, FS_max_slow = search.get_FS_max() + + search.run_computefstatistic(dFreq=0, numFreqBins=1) + _, _, _, FS_max = search.get_FS_max() + self.assertTrue( + np.abs(FS_max-FS_max_slow)/FS_max_slow < 0.1) + + def test_compute_fstat_against_predict_fstat(self): + Writer = glitch_tools.Writer(self.label, outdir=self.outdir) + Writer.make_data() + Writer.run_makefakedata() + predicted_FS = Writer.predict_fstat() + + search = glitch_searches.FullyCoherentNarrowBandSearch( + self.label, self.outdir, tref=Writer.tref, Alpha=Writer.Alpha, + Delta=Writer.Delta, duration=Writer.duration, tstart=Writer.tstart, + Writer=Writer) + search.run_computefstatistic(dFreq=0, numFreqBins=1) + _, _, _, FS_max = search.get_FS_max() + self.assertTrue(np.abs(predicted_FS-FS_max)/FS_max < 0.5) + + +class TestSemiCoherentGlitchSearch(unittest.TestCase): + label = "Test" + outdir = 'TestData' + + def test_run_computefstatistic_single_point(self): + Writer = pyfstat.Writer(self.label, outdir=self.outdir) + Writer.make_data() + predicted_FS = Writer.predict_fstat() + + search = pyfstat.SemiCoherentGlitchSearch( + label=Writer.label, outdir=Writer.outdir, tref=Writer.tref, + tstart=Writer.tstart, tend=Writer.tend) + FS = search.run_computefstatistic_single_point(search.tref, + search.tstart, + search.tend, + Writer.F0, + Writer.F1, + Writer.F2, + Writer.Alpha, + Writer.Delta) + print predicted_FS, FS + self.assertTrue(np.abs(predicted_FS-FS)/FS < 0.1) + + def test_run_computefstatistic_single_point_slow(self): + Writer = pyfstat.Writer(self.label, outdir=self.outdir) + Writer.make_data() + predicted_FS = Writer.predict_fstat() + + search = pyfstat.SemiCoherentGlitchSearch( + label=Writer.label, outdir=Writer.outdir, tref=Writer.tref, + tstart=Writer.tstart, tend=Writer.tend) + FS = search.run_computefstatistic_single_point_slow(search.tref, + search.tstart, + search.tend, + Writer.F0, + Writer.F1, + Writer.F2, + Writer.Alpha, + Writer.Delta) + self.assertTrue(np.abs(predicted_FS-FS)/FS < 0.1) + + def test_compute_glitch_fstat_slow(self): + duration = 100*86400 + dtglitch = 100*43200 + delta_F0 = 0 + Writer = pyfstat.Writer(self.label, outdir=self.outdir, + duration=duration, dtglitch=dtglitch, + delta_F0=delta_F0) + Writer.make_data() + + search = pyfstat.SemiCoherentGlitchSearch( + label=Writer.label, outdir=Writer.outdir, tref=Writer.tref, + tstart=Writer.tstart, tend=Writer.tend) + + FS = search.compute_glitch_fstat_slow(Writer.F0, Writer.F1, Writer.F2, + Writer.Alpha, Writer.Delta, + Writer.delta_F0, Writer.delta_F1, + Writer.tglitch) + + # Compute the predicted semi-coherent glitch Fstat + tstart = Writer.tstart + tend = Writer.tend + + Writer.tend = tstart + dtglitch + FSA = Writer.predict_fstat() + + Writer.tstart = tstart + dtglitch + Writer.tend = tend + FSB = Writer.predict_fstat() + + predicted_FS = .5*(FSA + FSB) + + print(predicted_FS, FS) + self.assertTrue(np.abs((FS - predicted_FS))/predicted_FS < 0.1) + + def test_compute_nglitch_fstat(self): + duration = 100*86400 + dtglitch = 100*43200 + delta_F0 = 0 + Writer = pyfstat.Writer(self.label, outdir=self.outdir, + duration=duration, dtglitch=dtglitch, + delta_F0=delta_F0) + + Writer.make_data() + + search = pyfstat.SemiCoherentGlitchSearch( + label=Writer.label, outdir=Writer.outdir, tref=Writer.tref, + tstart=Writer.tstart, tend=Writer.tend, nglitch=1) + + FS = search.compute_nglitch_fstat(Writer.F0, Writer.F1, Writer.F2, + Writer.Alpha, Writer.Delta, + Writer.delta_F0, Writer.delta_F1, + search.tstart+dtglitch) + + # Compute the predicted semi-coherent glitch Fstat + tstart = Writer.tstart + tend = Writer.tend + + Writer.tend = tstart + dtglitch + FSA = Writer.predict_fstat() + + Writer.tstart = tstart + dtglitch + Writer.tend = tend + FSB = Writer.predict_fstat() + + predicted_FS = (FSA + FSB) + + print(predicted_FS, FS) + self.assertTrue(np.abs((FS - predicted_FS))/predicted_FS < 0.1) + + +class TestMCMCGlitchSearch(unittest.TestCase): + label = "MCMCTest" + outdir = 'TestData' + + def test_fully_coherent(self): + h0 = 1e-24 + sqrtSX = 1e-22 + F0 = 30 + F1 = -1e-10 + F2 = 0 + tstart = 700000000 + duration = 100 * 86400 + tend = tstart + duration + Alpha = 5e-3 + Delta = 1.2 + tref = tstart + dtglitch = duration + delta_F0 = 0 + Writer = pyfstat.Writer(F0=F0, F1=F1, F2=F2, label=self.label, + h0=h0, sqrtSX=sqrtSX, + outdir=self.outdir, tstart=tstart, + Alpha=Alpha, Delta=Delta, tref=tref, + duration=duration, dtglitch=dtglitch, + delta_F0=delta_F0, Band=4) + + Writer.make_data() + predicted_FS = Writer.predict_fstat() + + theta = {'delta_F0': 0, 'delta_F1': 0, 'tglitch': tend, + 'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-9*F0)}, + 'F1': {'type': 'norm', 'loc': F1, 'scale': np.abs(1e-9*F1)}, + 'F2': F2, 'Alpha': Alpha, 'Delta': Delta} + + search = pyfstat.MCMCGlitchSearch( + label=self.label, outdir=self.outdir, theta=theta, tref=tref, + sftlabel=self.label, sftdir=self.outdir, + tstart=tstart, tend=tend, nsteps=[100, 100], nwalkers=100, + ntemps=1) + search.run() + search.plot_corner(add_prior=True) + _, FS = search.get_max_twoF() + + print('Predicted twoF is {} while recovered is {}'.format( + predicted_FS, FS)) + self.assertTrue(np.abs((FS - predicted_FS))/predicted_FS < 0.1) + + +if __name__ == '__main__': + unittest.main()