Commit 51d107a0 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Refactoring the pyfstat code

Transforms the single pyfstat.py into a python module splitting the
relevant code into separate sub-files in pyfstat. This should result in
improved readability.
parent 9c875126
......@@ -40,11 +40,7 @@ are provided in the links.
### Dependencies
`pyfstat` makes use of a variety python modules listed as the
`imports` in the top of `pyfstat.py`. The first set are core modules (such as
`os`, `sys`) while the second set are external and need to be installed for
`pyfstat` to work properly. Please install the following widely available
modules:
`pyfstat` makes use of the following external Python modules:
* [numpy](http://www.numpy.org/)
* [matplotlib](http://matplotlib.org/)
......
from __future__ import division
from .core import BaseSearchClass, ComputeFstat, Writer
from .mcmc_based_searches import *
from .grid_based_searches import *
""" The core tools used in pyfstat """
import os
import logging
import copy
import glob
import subprocess
import numpy as np
import matplotlib.pyplot as plt
import scipy.special
import scipy.optimize
import lal
import lalpulsar
import helper_functions
# Module-level setup delegated to helper_functions:
# optional tqdm progress bar (falls back gracefully if tqdm is unavailable)
tqdm = helper_functions.set_up_optional_tqdm()
helper_functions.set_up_matplotlib_defaults()
# Parsed command-line arguments (e.g. interactivity/verbosity flags used below)
args = helper_functions.set_up_command_line_arguments()
# Default paths to the Earth and Sun ephemerides files used by the searches
earth_ephem, sun_ephem = helper_functions.set_up_ephemeris_configuration()
def read_par(label, outdir):
    """ Read in a .par file and return a dictionary of its values.

    Parameters
    ----------
    label, outdir: str
        Label and directory locating the `<outdir>/<label>.par` file.

    Returns
    -------
    d: dict
        Mapping from parameter name to np.float64 value.
    """
    filename = '{}/{}.par'.format(outdir, label)
    d = {}
    with open(filename, 'r') as f:
        for line in f:
            if '=' not in line:
                continue
            # Split on the first '=' only; the previous split(' = ') raised
            # ValueError on lines written as 'key=val' or 'key  =  val'.
            key, _, val = line.rstrip('\n').partition('=')
            key = key.strip()
            # SECURITY NOTE: eval() executes arbitrary expressions read from
            # the .par file -- only use on trusted files. Kept for backward
            # compatibility with values written as python expressions.
            d[key] = np.float64(eval(val.strip().rstrip('; ')))
    return d
class BaseSearchClass(object):
    """ The base search class, provides general functions """

    # Module-level ephemerides paths, used by subclasses when the caller
    # does not provide explicit earth_ephem/sun_ephem arguments.
    earth_ephem_default = earth_ephem
    sun_ephem_default = sun_ephem

    def add_log_file(self):
        """ Log output to a file, requires class to have outdir and label """
        logfilename = '{}/{}.log'.format(self.outdir, self.label)
        fh = logging.FileHandler(logfilename)
        fh.setLevel(logging.INFO)
        fh.setFormatter(logging.Formatter(
            '%(asctime)s %(levelname)-8s: %(message)s',
            datefmt='%y-%m-%d %H:%M'))
        logging.getLogger().addHandler(fh)

    def shift_matrix(self, n, dT):
        """ Generate the shift matrix

        Parameters
        ----------
        n: int
            The dimension of the shift-matrix to generate
        dT: float
            The time delta of the shift matrix

        Returns
        -------
        m: array (n, n)
            The shift matrix
        """
        m = np.zeros((n, n))
        factorial = np.math.factorial
        for i in range(n):
            for j in range(n):
                if i == j:
                    m[i, j] = 1.0
                elif i > j:
                    m[i, j] = 0.0
                else:
                    if i == 0:
                        # The zeroth (phase) row carries a factor 2*pi
                        # relative to the frequency-derivative rows.
                        m[i, j] = 2*np.pi*float(dT)**(j-i) / factorial(j-i)
                    else:
                        m[i, j] = float(dT)**(j-i) / factorial(j-i)
        return m

    def shift_coefficients(self, theta, dT):
        """ Shift a set of coefficients by dT

        Parameters
        ----------
        theta: array-like, shape (n,)
            vector of the expansion coefficients to transform starting from
            the lowest degree e.g [phi, F0, F1,...].
        dT: float
            difference between the two reference times as tref_new - tref_old.

        Returns
        -------
        theta_new: array-like shape (n,)
            vector of the coefficients as evaluated at the new reference time.
        """
        n = len(theta)
        m = self.shift_matrix(n, dT)
        return np.dot(m, theta)

    def calculate_thetas(self, theta, delta_thetas, tbounds, theta0_idx=0):
        """ Calculates the set of coefficients for the post-glitch signal

        Parameters
        ----------
        theta: array-like
            Coefficients of the reference segment (index theta0_idx).
        delta_thetas: array-like
            Coefficient jumps at each glitch.
        tbounds: array-like
            Segment boundary times; tbounds[i+1] is the time of glitch i.
        theta0_idx: int
            Index of the segment that `theta` refers to.

        Returns
        -------
        thetas: list of arrays
            One coefficient vector per segment, all referred to self.tref.
        """
        thetas = [theta]
        for i, dt in enumerate(delta_thetas):
            if i < theta0_idx:
                # Glitch before the reference segment: shift to the glitch
                # time, subtract the jump, shift back, and prepend.
                pre_theta_at_ith_glitch = self.shift_coefficients(
                    thetas[0], tbounds[i+1] - self.tref)
                post_theta_at_ith_glitch = pre_theta_at_ith_glitch - dt
                thetas.insert(0, self.shift_coefficients(
                    post_theta_at_ith_glitch, self.tref - tbounds[i+1]))
            elif i >= theta0_idx:
                # Glitch at or after the reference segment: add the jump
                # at the glitch time and append.
                pre_theta_at_ith_glitch = self.shift_coefficients(
                    thetas[i], tbounds[i+1] - self.tref)
                post_theta_at_ith_glitch = pre_theta_at_ith_glitch + dt
                thetas.append(self.shift_coefficients(
                    post_theta_at_ith_glitch, self.tref - tbounds[i+1]))
        return thetas

    def generate_loudest(self):
        """ Run lalapps_ComputeFstatistic_v2 to produce a .loudest file

        Reads the max-posterior parameters from the `<label>.par` file,
        filling any missing keys from self.theta_prior.
        """
        params = read_par(self.label, self.outdir)
        for key in ['Alpha', 'Delta', 'F0', 'F1']:
            if key not in params:
                params[key] = self.theta_prior[key]
        cmd = ('lalapps_ComputeFstatistic_v2 -a {} -d {} -f {} -s {} -D "{}"'
               ' --refTime={} --outputLoudest="{}/{}.loudest" '
               '--minStartTime={} --maxStartTime={}').format(
                    params['Alpha'], params['Delta'], params['F0'],
                    params['F1'], self.sftfilepath, params['tref'],
                    self.outdir, self.label, self.minStartTime,
                    self.maxStartTime)
        # BUGFIX: with shell=True the command must be a single string; extra
        # list elements are passed as arguments to the shell itself, not to
        # the command (see the subprocess documentation).
        subprocess.call(cmd, shell=True)
class ComputeFstat(object):
    """ Base class providing interface to `lalpulsar.ComputeFstat` """

    earth_ephem_default = earth_ephem
    sun_ephem_default = sun_ephem

    @helper_functions.initializer
    def __init__(self, tref, sftfilepath=None, minStartTime=None,
                 maxStartTime=None, binary=False, transient=True, BSGL=False,
                 detector=None, minCoverFreq=None, maxCoverFreq=None,
                 earth_ephem=None, sun_ephem=None, injectSources=None
                 ):
        """
        Parameters
        ----------
        tref: int
            GPS seconds of the reference time.
        sftfilepath: str
            File pattern to match SFTs
        minStartTime, maxStartTime: float GPStime
            Only use SFTs with timestamps starting from (including, excluding)
            this epoch
        binary: bool
            If true, search over binary parameters.
        transient: bool
            If true, allow for the Fstat to be computed over a transient range.
        BSGL: bool
            If true, compute the BSGL rather than the twoF value.
        detector: str
            Two character reference to the data to use, specify None for no
            constraint.
        minCoverFreq, maxCoverFreq: float
            The min and max cover frequency passed to CreateFstatInput, if
            either is None the range of frequencies in the SFT less 1Hz is
            used.
        earth_ephem, sun_ephem: str
            Paths of the two files containing positions of Earth and Sun,
            respectively at evenly spaced times, as passed to CreateFstatInput.
            If None defaults defined in BaseSearchClass will be used.
        injectSources: dict or None
            If a dict, parameters of a signal to inject into the data before
            computing the F-statistic; see
            init_computefstatistic_single_point for the keys that are read.
        """

        if earth_ephem is None:
            self.earth_ephem = self.earth_ephem_default
        if sun_ephem is None:
            self.sun_ephem = self.sun_ephem_default

        self.init_computefstatistic_single_point()

    def get_SFTCatalog(self):
        """ Load and cache the SFT catalog matching self.sftfilepath

        Sets self.SFTCatalog and self.detector_names; raises ValueError if
        no data matches the pattern/constraints.
        """
        if hasattr(self, 'SFTCatalog'):
            # Already initialised (e.g. by a subclass); nothing to do.
            return
        logging.info('Initialising SFTCatalog')
        constraints = lalpulsar.SFTConstraints()
        if self.detector:
            constraints.detector = self.detector
        if self.minStartTime:
            constraints.minStartTime = lal.LIGOTimeGPS(self.minStartTime)
        if self.maxStartTime:
            constraints.maxStartTime = lal.LIGOTimeGPS(self.maxStartTime)

        logging.info('Loading data matching pattern {}'.format(
                     self.sftfilepath))
        SFTCatalog = lalpulsar.SFTdataFind(self.sftfilepath, constraints)
        detector_names = list(set([d.header.name for d in SFTCatalog.data]))
        self.detector_names = detector_names
        SFT_timestamps = [d.header.epoch for d in SFTCatalog.data]
        # NOTE(review): 'quite' looks like a typo for 'quiet', but the
        # attribute must match the name declared in
        # helper_functions.set_up_command_line_arguments() -- confirm there
        # before renaming.
        if args.quite is False and args.no_interactive is False:
            try:
                # Optional dependency: silently skip the histogram if
                # bashplotlib is not installed.
                from bashplotlib.histogram import plot_hist
                print('Data timestamps histogram:')
                plot_hist(SFT_timestamps, height=5, bincount=50)
            except IOError:
                pass
        if len(detector_names) == 0:
            raise ValueError('No data loaded.')
        logging.info('Loaded {} data files from detectors {}'.format(
            len(SFT_timestamps), detector_names))
        logging.info('Data spans from {} ({}) to {} ({})'.format(
            int(SFT_timestamps[0]),
            subprocess.check_output('lalapps_tconvert {}'.format(
                int(SFT_timestamps[0])), shell=True).rstrip('\n'),
            int(SFT_timestamps[-1]),
            subprocess.check_output('lalapps_tconvert {}'.format(
                int(SFT_timestamps[-1])), shell=True).rstrip('\n')))
        self.SFTCatalog = SFTCatalog

    def get_list_of_matching_sfts(self):
        """ Return the list of SFT file paths matching self.sftfilepath

        Raises IOError if no files match.
        """
        matches = glob.glob(self.sftfilepath)
        if len(matches) > 0:
            return matches
        else:
            raise IOError('No sfts found matching {}'.format(
                self.sftfilepath))

    def init_computefstatistic_single_point(self):
        """ Initialisation step of run_computefstatistic for a single point """

        self.get_SFTCatalog()

        logging.info('Initialising ephems')
        ephems = lalpulsar.InitBarycenter(self.earth_ephem, self.sun_ephem)

        logging.info('Initialising FstatInput')
        dFreq = 0
        if self.transient:
            # Transient searches need the per-detector F-stat atoms
            self.whatToCompute = lalpulsar.FSTATQ_ATOMS_PER_DET
        else:
            self.whatToCompute = lalpulsar.FSTATQ_2F

        FstatOAs = lalpulsar.FstatOptionalArgs()
        FstatOAs.randSeed = lalpulsar.FstatOptionalArgsDefaults.randSeed
        FstatOAs.SSBprec = lalpulsar.FstatOptionalArgsDefaults.SSBprec
        FstatOAs.Dterms = lalpulsar.FstatOptionalArgsDefaults.Dterms
        FstatOAs.runningMedianWindow = lalpulsar.FstatOptionalArgsDefaults.runningMedianWindow
        FstatOAs.FstatMethod = lalpulsar.FstatOptionalArgsDefaults.FstatMethod
        FstatOAs.InjectSqrtSX = lalpulsar.FstatOptionalArgsDefaults.injectSqrtSX
        FstatOAs.assumeSqrtSX = lalpulsar.FstatOptionalArgsDefaults.assumeSqrtSX
        FstatOAs.prevInput = lalpulsar.FstatOptionalArgsDefaults.prevInput
        FstatOAs.collectTiming = lalpulsar.FstatOptionalArgsDefaults.collectTiming

        # BUGFIX: this previously tested hasattr(self, 'injectSource')
        # (missing the trailing 's'); that attribute is never set, so
        # requested injections were silently skipped.
        if hasattr(self, 'injectSources') and type(self.injectSources) == dict:
            logging.info('Injecting source with params: {}'.format(
                self.injectSources))
            PPV = lalpulsar.CreatePulsarParamsVector(1)
            PP = PPV.data[0]
            PP.Amp.h0 = self.injectSources['h0']
            PP.Amp.cosi = self.injectSources['cosi']
            PP.Amp.phi0 = self.injectSources['phi0']
            PP.Amp.psi = self.injectSources['psi']
            PP.Doppler.Alpha = self.injectSources['Alpha']
            PP.Doppler.Delta = self.injectSources['Delta']
            PP.Doppler.fkdot = np.array(self.injectSources['fkdot'])
            PP.Doppler.refTime = self.tref
            if 't0' not in self.injectSources:
                PP.Transient.type = lalpulsar.TRANSIENT_NONE
            FstatOAs.injectSources = PPV
        else:
            FstatOAs.injectSources = lalpulsar.FstatOptionalArgsDefaults.injectSources

        if self.minCoverFreq is None or self.maxCoverFreq is None:
            # Estimate the covered band from the SFTs themselves, backing
            # off 0.5 Hz from each edge.
            fAs = [d.header.f0 for d in self.SFTCatalog.data]
            fBs = [d.header.f0 + (d.numBins-1)*d.header.deltaF
                   for d in self.SFTCatalog.data]
            self.minCoverFreq = np.min(fAs) + 0.5
            self.maxCoverFreq = np.max(fBs) - 0.5
            logging.info('Min/max cover freqs not provided, using '
                         '{} and {}, est. from SFTs'.format(
                             self.minCoverFreq, self.maxCoverFreq))

        self.FstatInput = lalpulsar.CreateFstatInput(self.SFTCatalog,
                                                     self.minCoverFreq,
                                                     self.maxCoverFreq,
                                                     dFreq,
                                                     ephems,
                                                     FstatOAs
                                                     )

        logging.info('Initialising PulsarDoplerParams')
        PulsarDopplerParams = lalpulsar.PulsarDopplerParams()
        PulsarDopplerParams.refTime = self.tref
        # Placeholder values; overwritten on each call to
        # run_computefstatistic_single_point.
        PulsarDopplerParams.Alpha = 1
        PulsarDopplerParams.Delta = 1
        PulsarDopplerParams.fkdot = np.array([0, 0, 0, 0, 0, 0, 0])
        self.PulsarDopplerParams = PulsarDopplerParams

        logging.info('Initialising FstatResults')
        self.FstatResults = lalpulsar.FstatResults()

        if self.BSGL:
            if len(self.detector_names) < 2:
                raise ValueError("Can't use BSGL with single detector data")
            else:
                logging.info('Initialising BSGL')

            # Tuning parameters - to be reviewed
            numDetectors = 2
            if hasattr(self, 'nsegs'):
                # Semi-coherent case: pick Fstar0 so that the false-alarm
                # probability is approximately p_val_threshold.
                p_val_threshold = 1e-6
                Fstar0s = np.linspace(0, 1000, 10000)
                p_vals = scipy.special.gammaincc(2*self.nsegs, Fstar0s)
                Fstar0 = Fstar0s[np.argmin(np.abs(p_vals - p_val_threshold))]
                if Fstar0 == Fstar0s[-1]:
                    raise ValueError('Max Fstar0 exceeded')
            else:
                Fstar0 = 15.
            logging.info('Using Fstar0 of {:1.2f}'.format(Fstar0))
            oLGX = np.zeros(10)
            oLGX[:numDetectors] = 1./numDetectors
            self.BSGLSetup = lalpulsar.CreateBSGLSetup(numDetectors,
                                                       Fstar0,
                                                       oLGX,
                                                       True,
                                                       1)
            self.twoFX = np.zeros(10)
            # BSGL additionally needs the per-detector 2F values
            self.whatToCompute = (self.whatToCompute +
                                  lalpulsar.FSTATQ_2F_PER_DET)

        if self.transient:
            logging.info('Initialising transient parameters')
            self.windowRange = lalpulsar.transientWindowRange_t()
            self.windowRange.type = lalpulsar.TRANSIENT_RECTANGULAR
            self.windowRange.t0Band = 0
            self.windowRange.dt0 = 1
            self.windowRange.tauBand = 0
            self.windowRange.dtau = 1

    def compute_fullycoherent_det_stat_single_point(
            self, F0, F1, F2, Alpha, Delta, asini=None, period=None, ecc=None,
            tp=None, argp=None):
        """ Compute the fully-coherent det. statistic at a single point """
        return self.run_computefstatistic_single_point(
            self.minStartTime, self.maxStartTime, F0, F1, F2, Alpha, Delta,
            asini, period, ecc, tp, argp)

    def run_computefstatistic_single_point(self, tstart, tend, F0, F1,
                                           F2, Alpha, Delta, asini=None,
                                           period=None, ecc=None, tp=None,
                                           argp=None):
        """ Returns twoF or ln(BSGL) fully-coherently at a single point """
        self.PulsarDopplerParams.fkdot = np.array([F0, F1, F2, 0, 0, 0, 0])
        self.PulsarDopplerParams.Alpha = Alpha
        self.PulsarDopplerParams.Delta = Delta
        if self.binary:
            self.PulsarDopplerParams.asini = asini
            self.PulsarDopplerParams.period = period
            self.PulsarDopplerParams.ecc = ecc
            self.PulsarDopplerParams.tp = tp
            self.PulsarDopplerParams.argp = argp
        lalpulsar.ComputeFstat(self.FstatResults,
                               self.FstatInput,
                               self.PulsarDopplerParams,
                               1,
                               self.whatToCompute
                               )

        if self.transient is False:
            if self.BSGL is False:
                return self.FstatResults.twoF[0]

            # np.float was a deprecated alias of the builtin float and has
            # been removed in modern numpy; use the builtin directly.
            twoF = float(self.FstatResults.twoF[0])
            self.twoFX[0] = self.FstatResults.twoFPerDet(0)
            self.twoFX[1] = self.FstatResults.twoFPerDet(1)
            log10_BSGL = lalpulsar.ComputeBSGL(twoF, self.twoFX,
                                               self.BSGLSetup)
            # Convert from log10 to natural log
            return log10_BSGL/np.log10(np.exp(1))

        self.windowRange.t0 = int(tstart)  # TYPE UINT4
        self.windowRange.tau = int(tend - tstart)  # TYPE UINT4
        FS = lalpulsar.ComputeTransientFstatMap(
            self.FstatResults.multiFatoms[0], self.windowRange, False)

        if self.BSGL is False:
            return 2*FS.F_mn.data[0][0]

        FstatResults_single = copy.copy(self.FstatResults)
        # NOTE(review): 'lenth' looks like a typo for 'length'; left
        # unchanged pending confirmation of the attribute name expected by
        # the swig-wrapped FstatResults object.
        FstatResults_single.lenth = 1
        FstatResults_single.data = self.FstatResults.multiFatoms[0].data[0]
        FS0 = lalpulsar.ComputeTransientFstatMap(
            FstatResults_single.multiFatoms[0], self.windowRange, False)
        FstatResults_single.data = self.FstatResults.multiFatoms[0].data[1]
        FS1 = lalpulsar.ComputeTransientFstatMap(
            FstatResults_single.multiFatoms[0], self.windowRange, False)

        self.twoFX[0] = 2*FS0.F_mn.data[0][0]
        self.twoFX[1] = 2*FS1.F_mn.data[0][0]
        log10_BSGL = lalpulsar.ComputeBSGL(
            2*FS.F_mn.data[0][0], self.twoFX, self.BSGLSetup)

        return log10_BSGL/np.log10(np.exp(1))

    def calculate_twoF_cumulative(self, F0, F1, F2, Alpha, Delta, asini=None,
                                  period=None, ecc=None, tp=None, argp=None,
                                  tstart=None, tend=None, npoints=1000,
                                  minfraction=0.01, maxfraction=1):
        """ Calculate the cumulative twoF along the observation span

        Parameters
        ----------
        F0, F1, F2, Alpha, Delta: float
            Signal parameters at which to evaluate twoF.
        asini, period, ecc, tp, argp: float or None
            Binary parameters (used only if self.binary).
        tstart, tend: float
            GPS start and end times of the span.
        npoints: int
            Number of cumulative evaluations.
        minfraction, maxfraction: float
            Fractions of the span delimiting the first and last evaluation.

        Returns
        -------
        taus, twoFs: array
            Durations and the corresponding detection statistics.
        """
        duration = tend - tstart
        tstart = tstart + minfraction*duration
        taus = np.linspace(minfraction*duration, maxfraction*duration, npoints)
        twoFs = []
        if self.transient is False:
            # Transient machinery is required to restrict the time span;
            # switch it on and re-initialise.
            self.transient = True
            self.init_computefstatistic_single_point()
        for tau in taus:
            twoFs.append(self.run_computefstatistic_single_point(
                tstart=tstart, tend=tstart+tau, F0=F0, F1=F1, F2=F2,
                Alpha=Alpha, Delta=Delta, asini=asini, period=period, ecc=ecc,
                tp=tp, argp=argp))

        return taus, np.array(twoFs)

    def plot_twoF_cumulative(self, label, outdir, ax=None, c='k', savefig=True,
                             title=None, **kwargs):
        """ Plot the cumulative twoF (see calculate_twoF_cumulative)

        Parameters
        ----------
        label, outdir: str
            Label and directory used to name the saved figure.
        ax: matplotlib axis or None
            Axis to plot on; a new figure is created if None.
        c: str
            Line colour.
        savefig: bool
            If true, save the png and return (taus, twoFs); otherwise
            return the axis.
        title: str or None
            Optional axis title.
        **kwargs:
            Passed to calculate_twoF_cumulative; must include 'tstart'.
        """
        taus, twoFs = self.calculate_twoF_cumulative(**kwargs)
        if ax is None:
            fig, ax = plt.subplots()
        ax.plot(taus/86400., twoFs, label=label, color=c)
        ax.set_xlabel(r'Days from $t_{{\rm start}}={:.0f}$'.format(
            kwargs['tstart']))
        if self.BSGL:
            ax.set_ylabel(r'$\log_{10}(\mathrm{BSGL})_{\rm cumulative}$')
        else:
            ax.set_ylabel(r'$\widetilde{2\mathcal{F}}_{\rm cumulative}$')
        ax.set_xlim(0, taus[-1]/86400)
        if title:
            ax.set_title(title)
        if savefig:
            plt.tight_layout()
            plt.savefig('{}/{}_twoFcumulative.png'.format(outdir, label))
            return taus, twoFs
        else:
            return ax
class SemiCoherentSearch(BaseSearchClass, ComputeFstat):
    """ A semi-coherent search """

    @helper_functions.initializer
    def __init__(self, label, outdir, tref, nsegs=None, sftfilepath=None,
                 binary=False, BSGL=False, minStartTime=None,
                 maxStartTime=None, minCoverFreq=None, maxCoverFreq=None,
                 detector=None, earth_ephem=None, sun_ephem=None,
                 injectSources=None):
        """
        Parameters
        ----------
        label, outdir: str
            A label and directory to read/write data from/to.
        tref, minStartTime, maxStartTime: int
            GPS seconds of the reference time, and start and end of the data.
        nsegs: int
            The (fixed) number of segments
        sftfilepath: str
            File pattern to match SFTs

        For all other parameters, see pyfstat.ComputeFstat.
        """

        self.fs_file_name = "{}/{}_FS.dat".format(self.outdir, self.label)
        if self.earth_ephem is None:
            self.earth_ephem = self.earth_ephem_default
        if self.sun_ephem is None:
            self.sun_ephem = self.sun_ephem_default
        # Semi-coherent searches always use the transient machinery to
        # restrict the F-stat to each segment.
        self.transient = True
        self.init_computefstatistic_single_point()
        self.init_semicoherent_parameters()

    def init_semicoherent_parameters(self):
        """ Set up segment boundaries and the quantities to compute """
        logging.info(('Initialising semicoherent parameters from {} to {} in'
                      ' {} segments').format(
            self.minStartTime, self.maxStartTime, self.nsegs))
        self.transient = True
        self.whatToCompute = lalpulsar.FSTATQ_2F+lalpulsar.FSTATQ_ATOMS_PER_DET
        # nsegs+1 boundaries delimit nsegs equal-length segments
        self.tboundaries = np.linspace(self.minStartTime, self.maxStartTime,
                                       self.nsegs+1)

    def run_semi_coherent_computefstatistic_single_point(
            self, F0, F1, F2, Alpha, Delta, asini=None,
            period=None, ecc=None, tp=None, argp=None):
        """ Returns twoF or ln(BSGL) semi-coherently at a single point """
        self.PulsarDopplerParams.fkdot = np.array([F0, F1, F2, 0, 0, 0, 0])
        self.PulsarDopplerParams.Alpha = Alpha
        self.PulsarDopplerParams.Delta = Delta
        if self.binary:
            self.PulsarDopplerParams.asini = asini
            self.PulsarDopplerParams.period = period
            self.PulsarDopplerParams.ecc = ecc
            self.PulsarDopplerParams.tp = tp
            self.PulsarDopplerParams.argp = argp
        lalpulsar.ComputeFstat(self.FstatResults,
                               self.FstatInput,
                               self.PulsarDopplerParams,
                               1,
                               self.whatToCompute
                               )

        if self.transient is False:
            if self.BSGL is False:
                return self.FstatResults.twoF[0]

            # np.float was a deprecated alias of the builtin float and has
            # been removed in modern numpy; use the builtin directly.
            twoF = float(self.FstatResults.twoF[0])
            self.twoFX[0] = self.FstatResults.twoFPerDet(0)
            self.twoFX[1] = self.FstatResults.twoFPerDet(1)
            log10_BSGL = lalpulsar.ComputeBSGL(twoF, self.twoFX,
                                               self.BSGLSetup)
            return log10_BSGL/np.log10(np.exp(1))

        detStat = 0
        # Sum the per-segment detection statistic over all segments
        for tstart, tend in zip(self.tboundaries[:-1], self.tboundaries[1:]):
            self.windowRange.t0 = int(tstart)  # TYPE UINT4
            self.windowRange.tau = int(tend - tstart)  # TYPE UINT4

            FS = lalpulsar.ComputeTransientFstatMap(
                self.FstatResults.multiFatoms[0], self.windowRange, False)

            if self.BSGL is False:
                detStat += 2*FS.F_mn.data[0][0]
                continue

            FstatResults_single = copy.copy(self.FstatResults)
            # NOTE(review): 'lenth' looks like a typo for 'length'; left
            # unchanged pending confirmation of the attribute name expected
            # by the swig-wrapped FstatResults object.
            FstatResults_single.lenth = 1
            FstatResults_single.data = self.FstatResults.multiFatoms[0].data[0]
            FS0 = lalpulsar.ComputeTransientFstatMap(
                FstatResults_single.multiFatoms[0], self.windowRange, False)
            FstatResults_single.data = self.FstatResults.multiFatoms[0].data[1]
            FS1 = lalpulsar.ComputeTransientFstatMap(
                FstatResults_single.multiFatoms[0], self.windowRange, False)

            self.twoFX[0] = 2*FS0.F_mn.data[0][0]
            self.twoFX[1] = 2*FS1.F_mn.data[0][0]
            log10_BSGL = lalpulsar.ComputeBSGL(
                2*FS.F_mn.data[0][0], self.twoFX, self.BSGLSetup)

            detStat += log10_BSGL/np.log10(np.exp(1))

        return detStat
class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat):
""" A semi-coherent glitch search
This implements a basic `semi-coherent glitch F-stat in which the data
is divided into segments either side of the proposed glitches and the
fully-coherent F-stat in each segment is summed to give the semi-coherent
F-stat
"""
@helper_functions.initializer
def __init__(self, label, outdir, tref, minStartTime, maxStartTime,
nglitch=0, sftfilepath=None, theta0_idx=0, BSGL=False,
minCoverFreq=None, maxCoverFreq=None,
detector=None, earth_ephem=None, sun_ephem=None):
"""
Parameters
----------
label, outdir: str
A label and directory to read/write data from/to.
tref, minStartTime, maxStartTime: int
GPS seconds of the reference time, and start and end of the data.
nglitch: int
The (fixed) number of glitches; this can zero, but occasionally
this causes issue (in which case just use ComputeFstat).
sftfilepath: str