Commit 8ffae7e2 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Adds in-memory injection capabilities

parent 5d115632
...@@ -151,9 +151,9 @@ class ComputeFstat(object): ...@@ -151,9 +151,9 @@ class ComputeFstat(object):
@helper_functions.initializer @helper_functions.initializer
def __init__(self, tref, sftfilepath=None, minStartTime=None, def __init__(self, tref, sftfilepath=None, minStartTime=None,
maxStartTime=None, binary=False, transient=True, BSGL=False, maxStartTime=None, binary=False, transient=True, BSGL=False,
detector=None, minCoverFreq=None, maxCoverFreq=None, detectors=None, minCoverFreq=None, maxCoverFreq=None,
earth_ephem=None, sun_ephem=None, injectSources=None, earth_ephem=None, sun_ephem=None, injectSources=None,
assumeSqrtSX=None): injectSqrtSX=None, assumeSqrtSX=None):
""" """
Parameters Parameters
---------- ----------
...@@ -170,9 +170,9 @@ class ComputeFstat(object): ...@@ -170,9 +170,9 @@ class ComputeFstat(object):
If true, allow for the Fstat to be computed over a transient range. If true, allow for the Fstat to be computed over a transient range.
BSGL: bool BSGL: bool
If true, compute the BSGL rather than the twoF value. If true, compute the BSGL rather than the twoF value.
detector: str detectors: str
Two character reference to the data to use, specify None for no Two character reference to the data to use, specify None for no
contraint. contraint. If multiple-separate by comma.
minCoverFreq, maxCoverFreq: float minCoverFreq, maxCoverFreq: float
The min and max cover frequency passed to CreateFstatInput, if The min and max cover frequency passed to CreateFstatInput, if
either is None the range of frequencies in the SFT less 1Hz is either is None the range of frequencies in the SFT less 1Hz is
...@@ -181,6 +181,11 @@ class ComputeFstat(object): ...@@ -181,6 +181,11 @@ class ComputeFstat(object):
Paths of the two files containing positions of Earth and Sun, Paths of the two files containing positions of Earth and Sun,
respectively at evenly spaced times, as passed to CreateFstatInput. respectively at evenly spaced times, as passed to CreateFstatInput.
If None defaults defined in BaseSearchClass will be used. If None defaults defined in BaseSearchClass will be used.
injectSources: dict or str
Either a dictionary of the values to inject, or a string pointing
to the .cff file to inject
injectSqrtSX:
Not yet implemented
assumeSqrtSX: float assumeSqrtSX: float
Don't estimate noise-floors but assume (stationary) per-IFO Don't estimate noise-floors but assume (stationary) per-IFO
sqrt{SX} (if single value: use for all IFOs). If signal only, sqrt{SX} (if single value: use for all IFOs). If signal only,
...@@ -198,10 +203,32 @@ class ComputeFstat(object): ...@@ -198,10 +203,32 @@ class ComputeFstat(object):
def get_SFTCatalog(self): def get_SFTCatalog(self):
if hasattr(self, 'SFTCatalog'): if hasattr(self, 'SFTCatalog'):
return return
if self.sftfilepath is None:
for k in ['minStartTime', 'maxStartTime', 'detectors']:
if getattr(self, k) is None:
raise ValueError('You must provide "{}" to injectSources'
.format(k))
C1 = getattr(self, 'injectSources', None) is None
C2 = getattr(self, 'injectSqrtSX', None) is None
if C1 and C2:
raise ValueError('You must specify either one of injectSources'
' or injectSqrtSX')
SFTCatalog = lalpulsar.SFTCatalog()
Tsft = 1800
Toverlap = 0
Tspan = self.maxStartTime - self.minStartTime
detNames = lal.CreateStringVector(
*[d for d in self.detectors.split(',')])
multiTimestamps = lalpulsar.MakeMultiTimestamps(
self.minStartTime, Tspan, Tsft, Toverlap, detNames.length)
SFTCatalog = lalpulsar.MultiAddToFakeSFTCatalog(
SFTCatalog, detNames, multiTimestamps)
return SFTCatalog
logging.info('Initialising SFTCatalog') logging.info('Initialising SFTCatalog')
constraints = lalpulsar.SFTConstraints() constraints = lalpulsar.SFTConstraints()
if self.detector: if self.detectors:
constraints.detector = self.detector constraints.detectors = self.detectors
if self.minStartTime: if self.minStartTime:
constraints.minStartTime = lal.LIGOTimeGPS(self.minStartTime) constraints.minStartTime = lal.LIGOTimeGPS(self.minStartTime)
if self.maxStartTime: if self.maxStartTime:
...@@ -231,12 +258,12 @@ class ComputeFstat(object): ...@@ -231,12 +258,12 @@ class ComputeFstat(object):
int(SFT_timestamps[-1]), int(SFT_timestamps[-1]),
subprocess.check_output('lalapps_tconvert {}'.format( subprocess.check_output('lalapps_tconvert {}'.format(
int(SFT_timestamps[-1])), shell=True).rstrip('\n'))) int(SFT_timestamps[-1])), shell=True).rstrip('\n')))
self.SFTCatalog = SFTCatalog return SFTCatalog
def init_computefstatistic_single_point(self): def init_computefstatistic_single_point(self):
""" Initilisation step of run_computefstatistic for a single point """ """ Initilisation step of run_computefstatistic for a single point """
self.get_SFTCatalog() SFTCatalog = self.get_SFTCatalog()
logging.info('Initialising ephems') logging.info('Initialising ephems')
ephems = lalpulsar.InitBarycenter(self.earth_ephem, self.sun_ephem) ephems = lalpulsar.InitBarycenter(self.earth_ephem, self.sun_ephem)
...@@ -254,7 +281,6 @@ class ComputeFstat(object): ...@@ -254,7 +281,6 @@ class ComputeFstat(object):
FstatOAs.Dterms = lalpulsar.FstatOptionalArgsDefaults.Dterms FstatOAs.Dterms = lalpulsar.FstatOptionalArgsDefaults.Dterms
FstatOAs.runningMedianWindow = lalpulsar.FstatOptionalArgsDefaults.runningMedianWindow FstatOAs.runningMedianWindow = lalpulsar.FstatOptionalArgsDefaults.runningMedianWindow
FstatOAs.FstatMethod = lalpulsar.FstatOptionalArgsDefaults.FstatMethod FstatOAs.FstatMethod = lalpulsar.FstatOptionalArgsDefaults.FstatMethod
FstatOAs.InjectSqrtSX = lalpulsar.FstatOptionalArgsDefaults.injectSqrtSX
if self.assumeSqrtSX is None: if self.assumeSqrtSX is None:
FstatOAs.assumeSqrtSX = lalpulsar.FstatOptionalArgsDefaults.assumeSqrtSX FstatOAs.assumeSqrtSX = lalpulsar.FstatOptionalArgsDefaults.assumeSqrtSX
else: else:
...@@ -282,20 +308,28 @@ class ComputeFstat(object): ...@@ -282,20 +308,28 @@ class ComputeFstat(object):
if 't0' not in self.injectSources: if 't0' not in self.injectSources:
PP.Transient.type = lalpulsar.TRANSIENT_NONE PP.Transient.type = lalpulsar.TRANSIENT_NONE
FstatOAs.injectSources = PPV FstatOAs.injectSources = PPV
if hasattr(self, 'injectSources') and type(self.injectSources) == str:
logging.info('Injecting source from param file: {}'.format(
self.injectSources))
PPV = lalpulsar.PulsarParamsFromFile(self.injectSources, self.tref)
FstatOAs.injectSources = PPV
else: else:
FstatOAs.injectSources = lalpulsar.FstatOptionalArgsDefaults.injectSources FstatOAs.injectSources = lalpulsar.FstatOptionalArgsDefaults.injectSources
if hasattr(self, 'injectSqrtSX') and self.injectSqrtSX is not None:
raise ValueError('injectSqrtSX not implemented')
else:
FstatOAs.InjectSqrtSX = lalpulsar.FstatOptionalArgsDefaults.injectSqrtSX
if self.minCoverFreq is None or self.maxCoverFreq is None: if self.minCoverFreq is None or self.maxCoverFreq is None:
fAs = [d.header.f0 for d in self.SFTCatalog.data] fAs = [d.header.f0 for d in SFTCatalog.data]
fBs = [d.header.f0 + (d.numBins-1)*d.header.deltaF fBs = [d.header.f0 + (d.numBins-1)*d.header.deltaF
for d in self.SFTCatalog.data] for d in SFTCatalog.data]
self.minCoverFreq = np.min(fAs) + 0.5 self.minCoverFreq = np.min(fAs) + 0.5
self.maxCoverFreq = np.max(fBs) - 0.5 self.maxCoverFreq = np.max(fBs) - 0.5
logging.info('Min/max cover freqs not provided, using ' logging.info('Min/max cover freqs not provided, using '
'{} and {}, est. from SFTs'.format( '{} and {}, est. from SFTs'.format(
self.minCoverFreq, self.maxCoverFreq)) self.minCoverFreq, self.maxCoverFreq))
self.FstatInput = lalpulsar.CreateFstatInput(self.SFTCatalog, self.FstatInput = lalpulsar.CreateFstatInput(SFTCatalog,
self.minCoverFreq, self.minCoverFreq,
self.maxCoverFreq, self.maxCoverFreq,
dFreq, dFreq,
...@@ -316,7 +350,7 @@ class ComputeFstat(object): ...@@ -316,7 +350,7 @@ class ComputeFstat(object):
if self.BSGL: if self.BSGL:
if len(self.detector_names) < 2: if len(self.detector_names) < 2:
raise ValueError("Can't use BSGL with single detector data") raise ValueError("Can't use BSGL with single detectors data")
else: else:
logging.info('Initialising BSGL') logging.info('Initialising BSGL')
...@@ -471,7 +505,7 @@ class SemiCoherentSearch(BaseSearchClass, ComputeFstat): ...@@ -471,7 +505,7 @@ class SemiCoherentSearch(BaseSearchClass, ComputeFstat):
def __init__(self, label, outdir, tref, nsegs=None, sftfilepath=None, def __init__(self, label, outdir, tref, nsegs=None, sftfilepath=None,
binary=False, BSGL=False, minStartTime=None, binary=False, BSGL=False, minStartTime=None,
maxStartTime=None, minCoverFreq=None, maxCoverFreq=None, maxStartTime=None, minCoverFreq=None, maxCoverFreq=None,
detector=None, earth_ephem=None, sun_ephem=None, detectors=None, earth_ephem=None, sun_ephem=None,
injectSources=None, assumeSqrtSX=None): injectSources=None, assumeSqrtSX=None):
""" """
Parameters Parameters
...@@ -583,7 +617,7 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat): ...@@ -583,7 +617,7 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat):
def __init__(self, label, outdir, tref, minStartTime, maxStartTime, def __init__(self, label, outdir, tref, minStartTime, maxStartTime,
nglitch=0, sftfilepath=None, theta0_idx=0, BSGL=False, nglitch=0, sftfilepath=None, theta0_idx=0, BSGL=False,
minCoverFreq=None, maxCoverFreq=None, assumeSqrtSX=None, minCoverFreq=None, maxCoverFreq=None, assumeSqrtSX=None,
detector=None, earth_ephem=None, sun_ephem=None): detectors=None, earth_ephem=None, sun_ephem=None):
""" """
Parameters Parameters
---------- ----------
...@@ -682,7 +716,7 @@ class Writer(BaseSearchClass): ...@@ -682,7 +716,7 @@ class Writer(BaseSearchClass):
dtglitch=None, delta_phi=0, delta_F0=0, delta_F1=0, dtglitch=None, delta_phi=0, delta_F0=0, delta_F1=0,
delta_F2=0, tref=None, F0=30, F1=1e-10, F2=0, Alpha=5e-3, delta_F2=0, tref=None, F0=30, F1=1e-10, F2=0, Alpha=5e-3,
Delta=6e-2, h0=0.1, cosi=0.0, psi=0.0, phi=0, Tsft=1800, Delta=6e-2, h0=0.1, cosi=0.0, psi=0.0, phi=0, Tsft=1800,
outdir=".", sqrtSX=1, Band=4, detector='H1', outdir=".", sqrtSX=1, Band=4, detectors='H1',
minStartTime=None, maxStartTime=None, add_noise=True): minStartTime=None, maxStartTime=None, add_noise=True):
""" """
Parameters Parameters
...@@ -888,7 +922,7 @@ transientTauDays={:1.3f}\n""") ...@@ -888,7 +922,7 @@ transientTauDays={:1.3f}\n""")
cl.append('--outSingleSFT=TRUE') cl.append('--outSingleSFT=TRUE')
cl.append('--outSFTdir="{}"'.format(self.outdir)) cl.append('--outSFTdir="{}"'.format(self.outdir))
cl.append('--outLabel="{}"'.format(self.label)) cl.append('--outLabel="{}"'.format(self.label))
cl.append('--IFOs="{}"'.format(self.detector)) cl.append('--IFOs="{}"'.format(self.detectors))
if self.add_noise: if self.add_noise:
cl.append('--sqrtSX="{}"'.format(self.sqrtSX)) cl.append('--sqrtSX="{}"'.format(self.sqrtSX))
if self.minStartTime is None: if self.minStartTime is None:
......
...@@ -28,7 +28,7 @@ class MCMCSearch(BaseSearchClass): ...@@ -28,7 +28,7 @@ class MCMCSearch(BaseSearchClass):
nwalkers=100, ntemps=1, log10temperature_min=-5, nwalkers=100, ntemps=1, log10temperature_min=-5,
theta_initial=None, scatter_val=1e-10, theta_initial=None, scatter_val=1e-10,
binary=False, BSGL=False, minCoverFreq=None, binary=False, BSGL=False, minCoverFreq=None,
maxCoverFreq=None, detector=None, earth_ephem=None, maxCoverFreq=None, detectors=None, earth_ephem=None,
sun_ephem=None, injectSources=None, assumeSqrtSX=None): sun_ephem=None, injectSources=None, assumeSqrtSX=None):
""" """
Parameters Parameters
...@@ -60,7 +60,7 @@ class MCMCSearch(BaseSearchClass): ...@@ -60,7 +60,7 @@ class MCMCSearch(BaseSearchClass):
generated from np.logspace(0, log10temperature_min, ntemps). generated from np.logspace(0, log10temperature_min, ntemps).
binary: Bool binary: Bool
If true, search over binary parameters If true, search over binary parameters
detector: str detectors: str
Two character reference to the data to use, specify None for no Two character reference to the data to use, specify None for no
contraint. contraint.
minCoverFreq, maxCoverFreq: float minCoverFreq, maxCoverFreq: float
...@@ -112,7 +112,7 @@ class MCMCSearch(BaseSearchClass): ...@@ -112,7 +112,7 @@ class MCMCSearch(BaseSearchClass):
tref=self.tref, sftfilepath=self.sftfilepath, tref=self.tref, sftfilepath=self.sftfilepath,
minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq, minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq,
earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem, earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem,
detector=self.detector, BSGL=self.BSGL, transient=False, detectors=self.detectors, BSGL=self.BSGL, transient=False,
minStartTime=self.minStartTime, maxStartTime=self.maxStartTime, minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
binary=self.binary, injectSources=self.injectSources, binary=self.binary, injectSources=self.injectSources,
assumeSqrtSX=self.assumeSqrtSX) assumeSqrtSX=self.assumeSqrtSX)
...@@ -1016,7 +1016,7 @@ class MCMCGlitchSearch(MCMCSearch): ...@@ -1016,7 +1016,7 @@ class MCMCGlitchSearch(MCMCSearch):
minStartTime, maxStartTime, nglitch=1, nsteps=[100, 100], minStartTime, maxStartTime, nglitch=1, nsteps=[100, 100],
nwalkers=100, ntemps=1, log10temperature_min=-5, nwalkers=100, ntemps=1, log10temperature_min=-5,
theta_initial=None, scatter_val=1e-10, dtglitchmin=1*86400, theta_initial=None, scatter_val=1e-10, dtglitchmin=1*86400,
theta0_idx=0, detector=None, BSGL=False, minCoverFreq=None, theta0_idx=0, detectors=None, BSGL=False, minCoverFreq=None,
maxCoverFreq=None, earth_ephem=None, sun_ephem=None): maxCoverFreq=None, earth_ephem=None, sun_ephem=None):
""" """
Parameters Parameters
...@@ -1060,7 +1060,7 @@ class MCMCGlitchSearch(MCMCSearch): ...@@ -1060,7 +1060,7 @@ class MCMCGlitchSearch(MCMCSearch):
Index (zero-based) of which segment the theta refers to - uyseful Index (zero-based) of which segment the theta refers to - uyseful
if providing a tight prior on theta to allow the signal to jump if providing a tight prior on theta to allow the signal to jump
too theta (and not just from) too theta (and not just from)
detector: str detectors: str
Two character reference to the data to use, specify None for no Two character reference to the data to use, specify None for no
contraint. contraint.
minCoverFreq, maxCoverFreq: float minCoverFreq, maxCoverFreq: float
...@@ -1104,7 +1104,7 @@ class MCMCGlitchSearch(MCMCSearch): ...@@ -1104,7 +1104,7 @@ class MCMCGlitchSearch(MCMCSearch):
tref=self.tref, minStartTime=self.minStartTime, tref=self.tref, minStartTime=self.minStartTime,
maxStartTime=self.maxStartTime, minCoverFreq=self.minCoverFreq, maxStartTime=self.maxStartTime, minCoverFreq=self.minCoverFreq,
maxCoverFreq=self.maxCoverFreq, earth_ephem=self.earth_ephem, maxCoverFreq=self.maxCoverFreq, earth_ephem=self.earth_ephem,
sun_ephem=self.sun_ephem, detector=self.detector, BSGL=self.BSGL, sun_ephem=self.sun_ephem, detectors=self.detectors, BSGL=self.BSGL,
nglitch=self.nglitch, theta0_idx=self.theta0_idx) nglitch=self.nglitch, theta0_idx=self.theta0_idx)
def logp(self, theta_vals, theta_prior, theta_keys, search): def logp(self, theta_vals, theta_prior, theta_keys, search):
...@@ -1263,10 +1263,10 @@ class MCMCGlitchSearch(MCMCSearch): ...@@ -1263,10 +1263,10 @@ class MCMCGlitchSearch(MCMCSearch):
class MCMCSemiCoherentSearch(MCMCSearch): class MCMCSemiCoherentSearch(MCMCSearch):
""" MCMC search for a signal using the semi-coherent ComputeFstat """ """ MCMC search for a signal using the semi-coherent ComputeFstat """
@helper_functions.initializer @helper_functions.initializer
def __init__(self, label, outdir, sftfilepath, theta_prior, tref, def __init__(self, label, outdir, theta_prior, tref, sftfilepath=None,
nsegs=None, nsteps=[100, 100, 100], nwalkers=100, binary=False, nsegs=None, nsteps=[100, 100, 100], nwalkers=100, binary=False,
ntemps=1, log10temperature_min=-5, theta_initial=None, ntemps=1, log10temperature_min=-5, theta_initial=None,
scatter_val=1e-10, detector=None, BSGL=False, scatter_val=1e-10, detectors=None, BSGL=False,
minStartTime=None, maxStartTime=None, minCoverFreq=None, minStartTime=None, maxStartTime=None, minCoverFreq=None,
maxCoverFreq=None, earth_ephem=None, sun_ephem=None, maxCoverFreq=None, earth_ephem=None, sun_ephem=None,
injectSources=None, assumeSqrtSX=None): injectSources=None, assumeSqrtSX=None):
...@@ -1304,7 +1304,7 @@ class MCMCSemiCoherentSearch(MCMCSearch): ...@@ -1304,7 +1304,7 @@ class MCMCSemiCoherentSearch(MCMCSearch):
nsegs=self.nsegs, sftfilepath=self.sftfilepath, binary=self.binary, nsegs=self.nsegs, sftfilepath=self.sftfilepath, binary=self.binary,
BSGL=self.BSGL, minStartTime=self.minStartTime, BSGL=self.BSGL, minStartTime=self.minStartTime,
maxStartTime=self.maxStartTime, minCoverFreq=self.minCoverFreq, maxStartTime=self.maxStartTime, minCoverFreq=self.minCoverFreq,
maxCoverFreq=self.maxCoverFreq, detector=self.detector, maxCoverFreq=self.maxCoverFreq, detectors=self.detectors,
earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem, earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem,
injectSources=self.injectSources, assumeSqrtSX=self.assumeSqrtSX) injectSources=self.injectSources, assumeSqrtSX=self.assumeSqrtSX)
...@@ -1641,7 +1641,7 @@ class MCMCTransientSearch(MCMCSearch): ...@@ -1641,7 +1641,7 @@ class MCMCTransientSearch(MCMCSearch):
tref=self.tref, sftfilepath=self.sftfilepath, tref=self.tref, sftfilepath=self.sftfilepath,
minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq, minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq,
earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem, earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem,
detector=self.detector, transient=True, detectors=self.detectors, transient=True,
minStartTime=self.minStartTime, maxStartTime=self.maxStartTime, minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
BSGL=self.BSGL, binary=self.binary) BSGL=self.BSGL, binary=self.binary)
......
...@@ -123,6 +123,29 @@ class TestComputeFstat(Test): ...@@ -123,6 +123,29 @@ class TestComputeFstat(Test):
print predicted_FS, FS print predicted_FS, FS
self.assertTrue(np.abs(predicted_FS-FS)/FS < 0.2) self.assertTrue(np.abs(predicted_FS-FS)/FS < 0.2)
def test_injectSources_from_file(self):
Writer = pyfstat.Writer(self.label, outdir=outdir, add_noise=False)
Writer.make_cff()
injectSources = Writer.config_file_name
search = pyfstat.ComputeFstat(
tref=Writer.tref, assumeSqrtSX=1, injectSources=injectSources,
minCoverFreq=28, maxCoverFreq=32, minStartTime=Writer.tstart,
maxStartTime=Writer.tstart+Writer.duration,
detectors=Writer.detectors)
FS = search.run_computefstatistic_single_point(Writer.tstart,
Writer.tend,
Writer.F0,
Writer.F1,
Writer.F2,
Writer.Alpha,
Writer.Delta)
Writer.make_data()
predicted_FS = Writer.predict_fstat()
print predicted_FS, FS
self.assertTrue(np.abs(predicted_FS-FS)/FS < 0.2)
class TestSemiCoherentGlitchSearch(Test): class TestSemiCoherentGlitchSearch(Test):
label = "Test" label = "Test"
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment