Skip to content
Snippets Groups Projects
Commit 89c44b66 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Reformatting of search codes

Create a separate class for pure ComputeFstat wrapper, then gives this
as a parent to the other searches. Ultimately, this will allow the use
of the ComputeFstat class without needing the glitch part, for example
for a transient search.
parent 7cb7b01e
No related branches found
No related tags found
No related merge requests found
...@@ -141,7 +141,101 @@ class BaseSearchClass(object): ...@@ -141,7 +141,101 @@ class BaseSearchClass(object):
return thetas return thetas
class SemiCoherentGlitchSearch(BaseSearchClass): class ComputeFstat(object):
""" Base class providing interface to lalpulsar.ComputeFstat """
earth_ephem_default = earth_ephem
sun_ephem_default = sun_ephem
@initializer
def __init__(self, tref, sftlabel=None, sftdir=None,
minCoverFreq=None, maxCoverFreq=None,
detector=None, earth_ephem=None, sun_ephem=None):
if earth_ephem is None:
self.earth_ephem = self.earth_ephem_default
if sun_ephem is None:
self.sun_ephem = self.sun_ephem_default
self.init_computefstatistic_single_point()
def init_computefstatistic_single_point(self):
""" Initilisation step of run_computefstatistic for a single point """
logging.info('Initialising SFTCatalog')
constraints = lalpulsar.SFTConstraints()
if self.detector:
constraints.detector = self.detector
self.sft_filepath = self.sftdir+'/*_'+self.sftlabel+"*sft"
SFTCatalog = lalpulsar.SFTdataFind(self.sft_filepath, constraints)
names = list(set([d.header.name for d in SFTCatalog.data]))
logging.info('Loaded data from detectors {}'.format(names))
logging.info('Initialising ephems')
ephems = lalpulsar.InitBarycenter(self.earth_ephem, self.sun_ephem)
logging.info('Initialising FstatInput')
dFreq = 0
self.whatToCompute = lalpulsar.FSTATQ_ATOMS_PER_DET
FstatOptionalArgs = lalpulsar.FstatOptionalArgsDefaults
if self.minCoverFreq is None or self.maxCoverFreq is None:
fA = SFTCatalog.data[0].header.f0
numBins = SFTCatalog.data[0].numBins
fB = fA + (numBins-1)*SFTCatalog.data[0].header.deltaF
self.minCoverFreq = fA + 0.5
self.maxCoverFreq = fB - 0.5
self.FstatInput = lalpulsar.CreateFstatInput(SFTCatalog,
self.minCoverFreq,
self.maxCoverFreq,
dFreq,
ephems,
FstatOptionalArgs
)
logging.info('Initialising PulsarDoplerParams')
PulsarDopplerParams = lalpulsar.PulsarDopplerParams()
PulsarDopplerParams.refTime = self.tref
PulsarDopplerParams.Alpha = 1
PulsarDopplerParams.Delta = 1
PulsarDopplerParams.fkdot = np.array([0, 0, 0, 0, 0, 0, 0])
self.PulsarDopplerParams = PulsarDopplerParams
logging.info('Initialising FstatResults')
self.FstatResults = lalpulsar.FstatResults()
def run_computefstatistic_single_point(self, tstart, tend, F0, F1,
F2, Alpha, Delta):
""" Compute the F-stat fully-coherently at a single point """
numFreqBins = 1
self.PulsarDopplerParams.fkdot = np.array([F0, F1, F2, 0, 0, 0, 0])
self.PulsarDopplerParams.Alpha = Alpha
self.PulsarDopplerParams.Delta = Delta
lalpulsar.ComputeFstat(self.FstatResults,
self.FstatInput,
self.PulsarDopplerParams,
numFreqBins,
self.whatToCompute
)
windowRange = lalpulsar.transientWindowRange_t()
windowRange.type = lalpulsar.TRANSIENT_RECTANGULAR
windowRange.t0 = int(tstart) # TYPE UINT4
windowRange.t0Band = 0
windowRange.dt0 = 1
windowRange.tau = int(tend - tstart) # TYPE UINT4
windowRange.tauBand = 0
windowRange.dtau = 1
useFReg = False
FS = lalpulsar.ComputeTransientFstatMap(
self.FstatResults.multiFatoms[0], windowRange, useFReg)
return 2*FS.F_mn.data[0][0]
class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat):
""" A semi-coherent glitch search """ A semi-coherent glitch search
This implements a basic `semi-coherent glitch F-stat in which the data This implements a basic `semi-coherent glitch F-stat in which the data
...@@ -214,7 +308,7 @@ class SemiCoherentGlitchSearch(BaseSearchClass): ...@@ -214,7 +308,7 @@ class SemiCoherentGlitchSearch(BaseSearchClass):
theta_post_glitch_at_glitch, tref - te) theta_post_glitch_at_glitch, tref - te)
twoFVal = self.run_computefstatistic_single_point( twoFVal = self.run_computefstatistic_single_point(
tref, ts, te, theta_at_tref[0], theta_at_tref[1], ts, te, theta_at_tref[0], theta_at_tref[1],
theta_at_tref[2], Alpha, Delta) theta_at_tref[2], Alpha, Delta)
twoFSum += twoFVal twoFSum += twoFVal
...@@ -234,95 +328,19 @@ class SemiCoherentGlitchSearch(BaseSearchClass): ...@@ -234,95 +328,19 @@ class SemiCoherentGlitchSearch(BaseSearchClass):
theta_post_glitch_at_glitch, tref - tglitch) theta_post_glitch_at_glitch, tref - tglitch)
twoFsegA = self.run_computefstatistic_single_point( twoFsegA = self.run_computefstatistic_single_point(
tref, self.tstart, tglitch, theta[0], theta[1], theta[2], Alpha, self.tstart, tglitch, theta[0], theta[1], theta[2], Alpha,
Delta) Delta)
if tglitch == self.tend: if tglitch == self.tend:
return twoFsegA return twoFsegA
twoFsegB = self.run_computefstatistic_single_point( twoFsegB = self.run_computefstatistic_single_point(
tref, tglitch, self.tend, theta_post_glitch[0], tglitch, self.tend, theta_post_glitch[0],
theta_post_glitch[1], theta_post_glitch[2], Alpha, theta_post_glitch[1], theta_post_glitch[2], Alpha,
Delta) Delta)
return twoFsegA + twoFsegB return twoFsegA + twoFsegB
def init_computefstatistic_single_point(self):
""" Initilisation step of run_computefstatistic for a single point """
logging.info('Initialising SFTCatalog')
constraints = lalpulsar.SFTConstraints()
if self.detector:
constraints.detector = self.detector
self.sft_filepath = self.sftdir+'/*_'+self.sftlabel+"*sft"
SFTCatalog = lalpulsar.SFTdataFind(self.sft_filepath, constraints)
names = list(set([d.header.name for d in SFTCatalog.data]))
logging.info('Loaded data from detectors {}'.format(names))
logging.info('Initialising ephems')
ephems = lalpulsar.InitBarycenter(self.earth_ephem, self.sun_ephem)
logging.info('Initialising FstatInput')
dFreq = 0
self.whatToCompute = lalpulsar.FSTATQ_ATOMS_PER_DET
FstatOptionalArgs = lalpulsar.FstatOptionalArgsDefaults
if self.minCoverFreq is None or self.maxCoverFreq is None:
fA = SFTCatalog.data[0].header.f0
numBins = SFTCatalog.data[0].numBins
fB = fA + (numBins-1)*SFTCatalog.data[0].header.deltaF
self.minCoverFreq = fA + 0.5
self.maxCoverFreq = fB - 0.5
self.FstatInput = lalpulsar.CreateFstatInput(SFTCatalog,
self.minCoverFreq,
self.maxCoverFreq,
dFreq,
ephems,
FstatOptionalArgs
)
logging.info('Initialising PulsarDoplerParams')
PulsarDopplerParams = lalpulsar.PulsarDopplerParams()
PulsarDopplerParams.refTime = self.tref
PulsarDopplerParams.Alpha = 1
PulsarDopplerParams.Delta = 1
PulsarDopplerParams.fkdot = np.array([0, 0, 0, 0, 0, 0, 0])
self.PulsarDopplerParams = PulsarDopplerParams
logging.info('Initialising FstatResults')
self.FstatResults = lalpulsar.FstatResults()
def run_computefstatistic_single_point(self, tref, tstart, tend, F0, F1,
F2, Alpha, Delta):
""" Compute the F-stat fully-coherently at a single point """
numFreqBins = 1
self.PulsarDopplerParams.fkdot = np.array([F0, F1, F2, 0, 0, 0, 0])
self.PulsarDopplerParams.Alpha = Alpha
self.PulsarDopplerParams.Delta = Delta
lalpulsar.ComputeFstat(self.FstatResults,
self.FstatInput,
self.PulsarDopplerParams,
numFreqBins,
self.whatToCompute
)
windowRange = lalpulsar.transientWindowRange_t()
windowRange.type = lalpulsar.TRANSIENT_RECTANGULAR
windowRange.t0 = int(tstart) # TYPE UINT4
windowRange.t0Band = 0
windowRange.dt0 = 1
windowRange.tau = int(tend - tstart) # TYPE UINT4
windowRange.tauBand = 0
windowRange.dtau = 1
useFReg = False
FS = lalpulsar.ComputeTransientFstatMap(self.FstatResults.multiFatoms[0],
windowRange,
useFReg)
return 2*FS.F_mn.data[0][0]
class MCMCGlitchSearch(BaseSearchClass): class MCMCGlitchSearch(BaseSearchClass):
""" MCMC search using the SemiCoherentGlitchSearch """ """ MCMC search using the SemiCoherentGlitchSearch """
......
...@@ -75,7 +75,7 @@ class TestBaseSearchClass(unittest.TestCase): ...@@ -75,7 +75,7 @@ class TestBaseSearchClass(unittest.TestCase):
rtol=1e-9, atol=1e-9)) rtol=1e-9, atol=1e-9))
class TestSemiCoherentGlitchSearch(unittest.TestCase): class TestComputeFstat(unittest.TestCase):
label = "Test" label = "Test"
outdir = 'TestData' outdir = 'TestData'
...@@ -84,12 +84,11 @@ class TestSemiCoherentGlitchSearch(unittest.TestCase): ...@@ -84,12 +84,11 @@ class TestSemiCoherentGlitchSearch(unittest.TestCase):
Writer.make_data() Writer.make_data()
predicted_FS = Writer.predict_fstat() predicted_FS = Writer.predict_fstat()
search = pyfstat.SemiCoherentGlitchSearch( search = pyfstat.ComputeFstat(tref=Writer.tref, sftlabel=Writer.label,
label=Writer.label, outdir=Writer.outdir, tref=Writer.tref, sftdir=Writer.outdir)
tstart=Writer.tstart, tend=Writer.tend) FS = search.run_computefstatistic_single_point(Writer.tref,
FS = search.run_computefstatistic_single_point(search.tref, Writer.tstart,
search.tstart, Writer.tend,
search.tend,
Writer.F0, Writer.F0,
Writer.F1, Writer.F1,
Writer.F2, Writer.F2,
...@@ -98,6 +97,11 @@ class TestSemiCoherentGlitchSearch(unittest.TestCase): ...@@ -98,6 +97,11 @@ class TestSemiCoherentGlitchSearch(unittest.TestCase):
print predicted_FS, FS print predicted_FS, FS
self.assertTrue(np.abs(predicted_FS-FS)/FS < 0.1) self.assertTrue(np.abs(predicted_FS-FS)/FS < 0.1)
class TestSemiCoherentGlitchSearch(unittest.TestCase):
label = "Test"
outdir = 'TestData'
def test_compute_nglitch_fstat(self): def test_compute_nglitch_fstat(self):
duration = 100*86400 duration = 100*86400
dtglitch = 100*43200 dtglitch = 100*43200
...@@ -131,7 +135,7 @@ class TestSemiCoherentGlitchSearch(unittest.TestCase): ...@@ -131,7 +135,7 @@ class TestSemiCoherentGlitchSearch(unittest.TestCase):
predicted_FS = (FSA + FSB) predicted_FS = (FSA + FSB)
print(predicted_FS, FS) print(predicted_FS, FS)
self.assertTrue(np.abs((FS - predicted_FS))/predicted_FS < 0.1) self.assertTrue(np.abs((FS - predicted_FS))/predicted_FS < 0.3)
class TestMCMCGlitchSearch(unittest.TestCase): class TestMCMCGlitchSearch(unittest.TestCase):
...@@ -178,7 +182,7 @@ class TestMCMCGlitchSearch(unittest.TestCase): ...@@ -178,7 +182,7 @@ class TestMCMCGlitchSearch(unittest.TestCase):
print('Predicted twoF is {} while recovered is {}'.format( print('Predicted twoF is {} while recovered is {}'.format(
predicted_FS, FS)) predicted_FS, FS))
self.assertTrue(np.abs((FS - predicted_FS))/predicted_FS < 0.1) self.assertTrue(np.abs((FS - predicted_FS))/predicted_FS < 0.3)
if __name__ == '__main__': if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment