Commit 53a140d5 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Adds tests and some fixes to make the important tests pass

parent 87426577
......@@ -8,6 +8,7 @@ import copy
import glob
import inspect
from functools import wraps
import subprocess
import numpy as np
import matplotlib
......@@ -25,7 +26,7 @@ if os.path.isfile(config_file):
for line in f:
k, v = line.split('=')
k = k.replace(' ', '')
v = v.replace(' ', '')
v = v.replace(' ', '').replace("'", "").replace('"', '').replace('\n', '')
d[k] = v
earth_ephem = d['earth_ephem']
sun_ephem = d['sun_ephem']
......@@ -37,6 +38,24 @@ else:
plt.style.use('paper')
parser = argparse.ArgumentParser()
parser.add_argument("-q", "--quite", help="Decrease output verbosity",
action="store_true")
parser.add_argument("-c", "--clean", help="Don't use cached data",
action="store_true")
parser.add_argument('unittest_args', nargs='*')
args, unknown = parser.parse_known_args()
sys.argv[1:] = args.unittest_args
if args.quite:
log_level = logging.WARNING
else:
log_level = logging.DEBUG
logging.basicConfig(level=log_level,
format='%(asctime)s %(levelname)-8s: %(message)s',
datefmt='%H:%M')
def initializer(func):
""" Automatically assigns the parameters to self"""
......@@ -65,24 +84,6 @@ def read_par(label, outdir):
d[key] = np.float64(val)
return d
parser = argparse.ArgumentParser()
parser.add_argument("-q", "--quite", help="Decrease output verbosity",
action="store_true")
parser.add_argument("-c", "--clean", help="Don't use cached data",
action="store_true")
parser.add_argument('unittest_args', nargs='*')
args, unknown = parser.parse_known_args()
sys.argv[1:] = args.unittest_args
if args.quite:
log_level = logging.WARNING
else:
log_level = logging.DEBUG
logging.basicConfig(level=log_level,
format='%(asctime)s %(levelname)-8s: %(message)s',
datefmt='%H:%M')
class BaseSearchClass(object):
......@@ -419,7 +420,7 @@ class SemiCoherentGlitchSearch(BaseSearchClass):
@initializer
def __init__(self, label, outdir, tref, tstart, tend, sftlabel=None,
nglitch=0, sftdir=None, minCoverFreq=29, maxCoverFreq=31,
nglitch=0, sftdir=None, minCoverFreq=None, maxCoverFreq=None,
detector=None, earth_ephem=None, sun_ephem=None):
"""
Parameters
......@@ -432,7 +433,9 @@ class SemiCoherentGlitchSearch(BaseSearchClass):
tref: int
GPS seconds of the reference time
minCoverFreq, maxCoverFreq: float
The min and max cover frequency passed to CreateFstatInput
The min and max cover frequency passed to CreateFstatInput, if
either is None the range of frequencies in the SFT less 1Hz is
used.
detector: str
Two character reference to the data to use, specify None for no
contraint
......@@ -531,6 +534,14 @@ class SemiCoherentGlitchSearch(BaseSearchClass):
dFreq = 0
self.whatToCompute = lalpulsar.FSTATQ_ATOMS_PER_DET
FstatOptionalArgs = lalpulsar.FstatOptionalArgsDefaults
if self.minCoverFreq is None or self.maxCoverFreq is None:
fA = SFTCatalog.data[0].header.f0
numBins = SFTCatalog.data[0].numBins
fB = fA + (numBins-1)*SFTCatalog.data[0].header.deltaF
self.minCoverFreq = fA + 0.5
self.maxCoverFreq = fB - 0.5
self.FstatInput = lalpulsar.CreateFstatInput(SFTCatalog,
self.minCoverFreq,
self.maxCoverFreq,
......@@ -1323,3 +1334,235 @@ class GridGlitchSearch(BaseSearchClass):
twoF = self.data[:, -1]
return np.max(twoF)
class Writer(BaseSearchClass):
""" Instance object for generating SFTs containing glitch signals """
@initializer
def __init__(self, label='Test', tstart=700000000, duration=100*86400,
dtglitch=None,
delta_phi=0, delta_F0=0, delta_F1=0, delta_F2=0,
tref=None, phi=0, F0=30, F1=1e-10, F2=0, Alpha=5e-3,
Delta=6e-2, h0=0.1, cosi=0.0, psi=0.0, Tsft=1800, outdir=".",
sqrtSX=1, Band=4):
"""
Parameters
----------
label: string
a human-readable label to be used in naming the output files
tstart, tend : float
start and end times (in gps seconds) of the total observation span
dtglitch: float
time (in gps seconds) of the glitch after tstart. To create data
without a glitch, set dtglitch=tend-tstart or leave as None
delta_phi, delta_F0, delta_F1: float
instanteneous glitch magnitudes in rad, Hz, and Hz/s respectively
tref: float or None
reference time (default is None, which sets the reference time to
tstart)
phil, F0, F1, F2, Alpha, Delta, h0, cosi, psi: float
pre-glitch phase, frequency, sky-position, and signal properties
Tsft: float
the sft duration
see `lalapps_Makefakedata_v5 --help` for help with the other paramaters
"""
for d in self.delta_phi, self.delta_F0, self.delta_F1, self.delta_F2:
if np.size(d) == 1:
d = [d]
self.tend = self.tstart + self.duration
if self.dtglitch is None or self.dtglitch == self.duration:
self.tbounds = [self.tstart, self.tend]
elif np.size(self.dtglitch) == 1:
self.tbounds = [self.tstart, self.tstart+self.dtglitch, self.tend]
else:
self.tglitch = self.tstart + np.array(self.dtglitch)
self.tbounds = [self.tstart] + list(self.tglitch) + [self.tend]
if os.path.isdir(self.outdir) is False:
os.makedirs(self.outdir)
if self.tref is None:
self.tref = self.tstart
self.tend = self.tstart + self.duration
tbs = np.array(self.tbounds)
self.durations_days = (tbs[1:] - tbs[:-1]) / 86400
self.config_file_name = "{}/{}.cff".format(outdir, label)
self.theta = np.array([phi, F0, F1, F2])
self.delta_thetas = np.atleast_2d(
np.array([delta_phi, delta_F0, delta_F1, delta_F2]).T)
self.detector = 'H1'
numSFTs = int(float(self.duration) / self.Tsft)
self.sft_filename = lalpulsar.OfficialSFTFilename(
'H', '1', numSFTs, self.Tsft, self.tstart, self.duration,
self.label)
self.sft_filepath = '{}/{}'.format(self.outdir, self.sft_filename)
self.calculate_fmin_Band()
def make_data(self):
''' A convienience wrapper to generate a cff file then sfts '''
self.make_cff()
self.run_makefakedata()
def get_single_config_line(self, i, Alpha, Delta, h0, cosi, psi, phi, F0,
F1, F2, tref, tstart, duration_days):
template = (
"""[TS{}]
Alpha = {:1.18e}
Delta = {:1.18e}
h0 = {:1.18e}
cosi = {:1.18e}
psi = {:1.18e}
phi0 = {:1.18e}
Freq = {:1.18e}
f1dot = {:1.18e}
f2dot = {:1.18e}
refTime = {:10.6f}
transientWindowType=rect
transientStartTime={:10.3f}
transientTauDays={:1.3f}\n""")
return template.format(i, Alpha, Delta, h0, cosi, psi, phi, F0, F1,
F2, tref, tstart, duration_days)
def make_cff(self):
"""
Generates an .cff file for a 'glitching' signal
"""
thetas = self.calculate_thetas(self.theta, self.delta_thetas,
self.tbounds)
content = ''
for i, (t, d, ts) in enumerate(zip(thetas, self.durations_days,
self.tbounds[:-1])):
line = self.get_single_config_line(
i, self.Alpha, self.Delta, self.h0, self.cosi, self.psi,
t[0], t[1], t[2], t[3], self.tref, ts, d)
content += line
if self.check_if_cff_file_needs_rewritting(content):
config_file = open(self.config_file_name, "w+")
config_file.write(content)
config_file.close()
def calculate_fmin_Band(self):
self.fmin = self.F0 - .5 * self.Band
def check_cached_data_okay_to_use(self, cl):
""" Check if cached data exists and, if it does, if it can be used """
getmtime = os.path.getmtime
if os.path.isfile(self.sft_filepath) is False:
logging.info('No SFT file matching {} found'.format(
self.sft_filepath))
return False
else:
logging.info('Matching SFT file found')
if getmtime(self.sft_filepath) < getmtime(self.config_file_name):
logging.info(
('The config file {} has been modified since the sft file {} '
+ 'was created').format(
self.config_file_name, self.sft_filepath))
return False
logging.info(
'The config file {} is older than the sft file {}'.format(
self.config_file_name, self.sft_filepath))
logging.info('Checking contents of cff file')
logging.info('Execute: {}'.format(
'lalapps_SFTdumpheader {} | head -n 20'.format(self.sft_filepath)))
output = subprocess.check_output(
'lalapps_SFTdumpheader {} | head -n 20'.format(self.sft_filepath),
shell=True)
calls = [line for line in output.split('\n') if line[:3] == 'lal']
if calls[0] == cl:
logging.info('Contents matched, use old sft file')
return True
else:
logging.info('Contents unmatched, create new sft file')
return False
def check_if_cff_file_needs_rewritting(self, content):
""" Check if the .cff file has changed
Returns True if the file should be overwritten - where possible avoid
overwriting to allow cached data to be used
"""
if os.path.isfile(self.config_file_name) is False:
logging.info('No config file {} found'.format(
self.config_file_name))
return True
else:
logging.info('Config file {} already exists'.format(
self.config_file_name))
with open(self.config_file_name, 'r') as f:
file_content = f.read()
if file_content == content:
logging.info(
'File contents match, no update of {} required'.format(
self.config_file_name))
return False
else:
logging.info(
'File contents unmatched, updating {}'.format(
self.config_file_name))
return True
def run_makefakedata(self):
""" Generate the sft data from the configuration file """
# Remove old data:
try:
os.unlink("{}/*{}*.sft".format(self.outdir, self.label))
except OSError:
pass
cl = []
cl.append('lalapps_Makefakedata_v5')
cl.append('--outSingleSFT=TRUE')
cl.append('--outSFTdir="{}"'.format(self.outdir))
cl.append('--outLabel="{}"'.format(self.label))
cl.append('--IFOs="{}"'.format(self.detector))
cl.append('--sqrtSX="{}"'.format(self.sqrtSX))
cl.append('--startTime={:10.9f}'.format(float(self.tstart)))
cl.append('--duration={}'.format(int(self.duration)))
cl.append('--fmin={}'.format(int(self.fmin)))
cl.append('--Band={}'.format(self.Band))
cl.append('--Tsft={}'.format(self.Tsft))
cl.append('--injectionSources="./{}"'.format(self.config_file_name))
cl = " ".join(cl)
if self.check_cached_data_okay_to_use(cl) is False:
logging.info("Executing: " + cl)
os.system(cl)
os.system('\n')
def predict_fstat(self):
""" Wrapper to lalapps_PredictFstat """
c_l = []
c_l.append("lalapps_PredictFstat")
c_l.append("--h0={}".format(self.h0))
c_l.append("--cosi={}".format(self.cosi))
c_l.append("--psi={}".format(self.psi))
c_l.append("--Alpha={}".format(self.Alpha))
c_l.append("--Delta={}".format(self.Delta))
c_l.append("--Freq={}".format(self.F0))
c_l.append("--DataFiles='{}'".format(
self.outdir+"/*SFT_"+self.label+"*sft"))
c_l.append("--assumeSqrtSX={}".format(self.sqrtSX))
c_l.append("--minStartTime={}".format(self.tstart))
c_l.append("--maxStartTime={}".format(self.tend))
logging.info("Executing: " + " ".join(c_l) + "\n")
output = subprocess.check_output(" ".join(c_l), shell=True)
twoF = float(output.split('\n')[-2])
return float(twoF)
import unittest
import pyfstat
import numpy as np
import os
class TestWriter(unittest.TestCase):
def test_make_cff(self):
label = "Test"
Writer = pyfstat.Writer(label, outdir='TestData')
Writer.make_cff()
self.assertTrue(os.path.isfile('./TestData/Test.cff'))
def test_run_makefakedata(self):
label = "Test"
Writer = pyfstat.Writer(label, outdir='TestData')
Writer.make_cff()
Writer.run_makefakedata()
self.assertTrue(os.path.isfile(
'./TestData/H-4800_H1_1800SFT_Test-700000000-8640000.sft'))
def test_makefakedata_usecached(self):
label = "Test"
Writer = pyfstat.Writer(label, outdir='TestData')
if os.path.isfile(Writer.sft_filepath):
os.remove(Writer.sft_filepath)
Writer.run_makefakedata()
time_first = os.path.getmtime(Writer.sft_filepath)
Writer.run_makefakedata()
time_second = os.path.getmtime(Writer.sft_filepath)
self.assertTrue(time_first == time_second)
os.system('touch {}'.format(Writer.config_file_name))
Writer.run_makefakedata()
time_third = os.path.getmtime(Writer.sft_filepath)
self.assertFalse(time_first == time_third)
class TestBaseSearchClass(unittest.TestCase):
def test_shift_matrix(self):
BSC = pyfstat.BaseSearchClass()
dT = 10
a = BSC.shift_matrix(4, dT)
b = np.array([[1, 2*np.pi*dT, 2*np.pi*dT**2/2.0, 2*np.pi*dT**3/6.0],
[0, 1, dT, dT**2/2.0],
[0, 0, 1, dT],
[0, 0, 0, 1]])
self.assertTrue(np.array_equal(a, b))
def test_shift_coefficients(self):
BSC = pyfstat.BaseSearchClass()
thetaA = np.array([10., 1e2, 10., 1e2])
dT = 100
# Calculate the 'long' way
thetaB = np.zeros(len(thetaA))
thetaB[3] = thetaA[3]
thetaB[2] = thetaA[2] + thetaA[3]*dT
thetaB[1] = thetaA[1] + thetaA[2]*dT + .5*thetaA[3]*dT**2
thetaB[0] = thetaA[0] + 2*np.pi*(thetaA[1]*dT + .5*thetaA[2]*dT**2
+ thetaA[3]*dT**3 / 6.0)
self.assertTrue(
np.array_equal(
thetaB, BSC.shift_coefficients(thetaA, dT)))
def test_shift_coefficients_loop(self):
BSC = pyfstat.BaseSearchClass()
thetaA = np.array([10., 1e2, 10., 1e2])
dT = 1e1
thetaB = BSC.shift_coefficients(thetaA, dT)
self.assertTrue(
np.allclose(
thetaA, BSC.shift_coefficients(thetaB, -dT),
rtol=1e-9, atol=1e-9))
class TestFullyCoherentNarrowBandSearch(unittest.TestCase):
label = "Test"
outdir = 'TestData'
def test_compute_fstat(self):
Writer = glitch_tools.Writer(self.label, outdir=self.outdir)
Writer.make_data()
search = glitch_searches.FullyCoherentNarrowBandSearch(
self.label, self.outdir, tref=Writer.tref, Alpha=Writer.Alpha,
Delta=Writer.Delta, duration=Writer.duration, tstart=Writer.tstart,
Writer=Writer)
search.run_computefstatistic_slow(m=1e-3, n=0)
_, _, _, FS_max_slow = search.get_FS_max()
search.run_computefstatistic(dFreq=0, numFreqBins=1)
_, _, _, FS_max = search.get_FS_max()
self.assertTrue(
np.abs(FS_max-FS_max_slow)/FS_max_slow < 0.1)
def test_compute_fstat_against_predict_fstat(self):
Writer = glitch_tools.Writer(self.label, outdir=self.outdir)
Writer.make_data()
Writer.run_makefakedata()
predicted_FS = Writer.predict_fstat()
search = glitch_searches.FullyCoherentNarrowBandSearch(
self.label, self.outdir, tref=Writer.tref, Alpha=Writer.Alpha,
Delta=Writer.Delta, duration=Writer.duration, tstart=Writer.tstart,
Writer=Writer)
search.run_computefstatistic(dFreq=0, numFreqBins=1)
_, _, _, FS_max = search.get_FS_max()
self.assertTrue(np.abs(predicted_FS-FS_max)/FS_max < 0.5)
class TestSemiCoherentGlitchSearch(unittest.TestCase):
label = "Test"
outdir = 'TestData'
def test_run_computefstatistic_single_point(self):
Writer = pyfstat.Writer(self.label, outdir=self.outdir)
Writer.make_data()
predicted_FS = Writer.predict_fstat()
search = pyfstat.SemiCoherentGlitchSearch(
label=Writer.label, outdir=Writer.outdir, tref=Writer.tref,
tstart=Writer.tstart, tend=Writer.tend)
FS = search.run_computefstatistic_single_point(search.tref,
search.tstart,
search.tend,
Writer.F0,
Writer.F1,
Writer.F2,
Writer.Alpha,
Writer.Delta)
print predicted_FS, FS
self.assertTrue(np.abs(predicted_FS-FS)/FS < 0.1)
def test_run_computefstatistic_single_point_slow(self):
Writer = pyfstat.Writer(self.label, outdir=self.outdir)
Writer.make_data()
predicted_FS = Writer.predict_fstat()
search = pyfstat.SemiCoherentGlitchSearch(
label=Writer.label, outdir=Writer.outdir, tref=Writer.tref,
tstart=Writer.tstart, tend=Writer.tend)
FS = search.run_computefstatistic_single_point_slow(search.tref,
search.tstart,
search.tend,
Writer.F0,
Writer.F1,
Writer.F2,
Writer.Alpha,
Writer.Delta)
self.assertTrue(np.abs(predicted_FS-FS)/FS < 0.1)
def test_compute_glitch_fstat_slow(self):
duration = 100*86400
dtglitch = 100*43200
delta_F0 = 0
Writer = pyfstat.Writer(self.label, outdir=self.outdir,
duration=duration, dtglitch=dtglitch,
delta_F0=delta_F0)
Writer.make_data()
search = pyfstat.SemiCoherentGlitchSearch(
label=Writer.label, outdir=Writer.outdir, tref=Writer.tref,
tstart=Writer.tstart, tend=Writer.tend)
FS = search.compute_glitch_fstat_slow(Writer.F0, Writer.F1, Writer.F2,
Writer.Alpha, Writer.Delta,
Writer.delta_F0, Writer.delta_F1,
Writer.tglitch)
# Compute the predicted semi-coherent glitch Fstat
tstart = Writer.tstart
tend = Writer.tend
Writer.tend = tstart + dtglitch
FSA = Writer.predict_fstat()
Writer.tstart = tstart + dtglitch
Writer.tend = tend
FSB = Writer.predict_fstat()
predicted_FS = .5*(FSA + FSB)
print(predicted_FS, FS)
self.assertTrue(np.abs((FS - predicted_FS))/predicted_FS < 0.1)
def test_compute_nglitch_fstat(self):
duration = 100*86400
dtglitch = 100*43200
delta_F0 = 0
Writer = pyfstat.Writer(self.label, outdir=self.outdir,
duration=duration, dtglitch=dtglitch,
delta_F0=delta_F0)
Writer.make_data()
search = pyfstat.SemiCoherentGlitchSearch(
label=Writer.label, outdir=Writer.outdir, tref=Writer.tref,
tstart=Writer.tstart, tend=Writer.tend, nglitch=1)
FS = search.compute_nglitch_fstat(Writer.F0, Writer.F1, Writer.F2,
Writer.Alpha, Writer.Delta,
Writer.delta_F0, Writer.delta_F1,
search.tstart+dtglitch)
# Compute the predicted semi-coherent glitch Fstat
tstart = Writer.tstart
tend = Writer.tend
Writer.tend = tstart + dtglitch
FSA = Writer.predict_fstat()
Writer.tstart = tstart + dtglitch
Writer.tend = tend
FSB = Writer.predict_fstat()
predicted_FS = (FSA + FSB)
print(predicted_FS, FS)
self.assertTrue(np.abs((FS - predicted_FS))/predicted_FS < 0.1)
class TestMCMCGlitchSearch(unittest.TestCase):
label = "MCMCTest"
outdir = 'TestData'
def test_fully_coherent(self):
h0 = 1e-24
sqrtSX = 1e-22
F0 = 30
F1 = -1e-10
F2 = 0
tstart = 700000000
duration = 100 * 86400
tend = tstart + duration
Alpha = 5e-3
Delta = 1.2
tref = tstart
dtglitch = duration
delta_F0 = 0
Writer = pyfstat.Writer(F0=F0, F1=F1, F2=F2, label=self.label,
h0=h0, sqrtSX=sqrtSX,
outdir=self.outdir, tstart=tstart,
Alpha=Alpha, Delta=Delta, tref=tref,
duration=duration, dtglitch=dtglitch,
delta_F0=delta_F0, Band=4)
Writer.make_data()
predicted_FS = Writer.predict_fstat()
theta = {'delta_F0': 0, 'delta_F1': 0, 'tglitch': tend,
'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-9*F0)},
'F1': {'type': 'norm', 'loc': F1, 'scale': np.abs(1e-9*F1)},
'F2': F2, 'Alpha': Alpha, 'Delta': Delta}
search = pyfstat.MCMCGlitchSearch(
label=self.label, outdir=self.outdir, theta=theta, tref=tref,
sftlabel=self.label, sftdir=self.outdir,
tstart=tstart, tend=tend, nsteps=[100, 100], nwalkers=100,
ntemps=1)
search.run()
search.plot_corner(add_prior=True)
_, FS = search.get_max_twoF()
print('Predicted twoF is {} while recovered is {}'.format(
predicted_FS, FS))
self.assertTrue(np.abs((FS - predicted_FS))/predicted_FS < 0.1)
if __name__ == '__main__':