Commit 51d107a0 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Refactoring the pyfstat code

Transforms the single pyfstat.py into a python module splitting the
relevant code into separate sub-files in pyfstat. This should result in
improved readability.
parent 9c875126
......@@ -40,11 +40,7 @@ are provided in the links.
### Dependencies
`pyfstat` makes use of a variety python modules listed as the
`imports` in the top of `pyfstat.py`. The first set are core modules (such as
`os`, `sys`) while the second set are external and need to be installed for
`pyfstat` to work properly. Please install the following widely available
modules:
`pyfstat` makes uses the following external python modules:
* [numpy](http://www.numpy.org/)
* [matplotlib](http://matplotlib.org/)
......
from __future__ import division
from .core import BaseSearchClass, ComputeFstat, Writer
from .mcmc_based_searches import *
from .grid_based_searches import *
This diff is collapsed.
""" Searches using grid-based methods """
import os
import logging
import itertools
from collections import OrderedDict
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import helper_functions
from core import BaseSearchClass, ComputeFstat, SemiCoherentGlitchSearch
from core import tqdm, args, earth_ephem, sun_ephem
class GridSearch(BaseSearchClass):
""" Gridded search using ComputeFstat """
@helper_functions.initializer
def __init__(self, label, outdir, sftfilepath, F0s=[0], F1s=[0], F2s=[0],
Alphas=[0], Deltas=[0], tref=None, minStartTime=None,
maxStartTime=None, BSGL=False, minCoverFreq=None,
maxCoverFreq=None, earth_ephem=None, sun_ephem=None,
detector=None):
"""
Parameters
----------
label, outdir: str
A label and directory to read/write data from/to
sftfilepath: str
File patern to match SFTs
F0s, F1s, F2s, delta_F0s, delta_F1s, tglitchs, Alphas, Deltas: tuple
Length 3 tuple describing the grid for each parameter, e.g
[F0min, F0max, dF0], for a fixed value simply give [F0].
tref, minStartTime, maxStartTime: int
GPS seconds of the reference time, start time and end time
For all other parameters, see `pyfstat.ComputeFStat` for details
"""
if earth_ephem is None:
self.earth_ephem = self.earth_ephem_default
if sun_ephem is None:
self.sun_ephem = self.sun_ephem_default
if os.path.isdir(outdir) is False:
os.mkdir(outdir)
self.out_file = '{}/{}_gridFS.txt'.format(self.outdir, self.label)
self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta']
def inititate_search_object(self):
logging.info('Setting up search object')
self.search = ComputeFstat(
tref=self.tref, sftfilepath=self.sftfilepath,
minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq,
earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem,
detector=self.detector, transient=False,
minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
BSGL=self.BSGL)
def get_array_from_tuple(self, x):
if len(x) == 1:
return np.array(x)
else:
return np.arange(x[0], x[1]*(1+1e-15), x[2])
def get_input_data_array(self):
arrays = []
for tup in ([self.minStartTime], [self.maxStartTime], self.F0s, self.F1s, self.F2s,
self.Alphas, self.Deltas):
arrays.append(self.get_array_from_tuple(tup))
input_data = []
for vals in itertools.product(*arrays):
input_data.append(vals)
self.arrays = arrays
self.input_data = np.array(input_data)
def check_old_data_is_okay_to_use(self):
if args.clean:
return False
if os.path.isfile(self.out_file) is False:
logging.info('No old data found, continuing with grid search')
return False
data = np.atleast_2d(np.genfromtxt(self.out_file, delimiter=' '))
if np.all(data[:, 0:-1] == self.input_data):
logging.info(
'Old data found with matching input, no search performed')
return data
else:
logging.info(
'Old data found, input differs, continuing with grid search')
return False
def run(self, return_data=False):
self.get_input_data_array()
old_data = self.check_old_data_is_okay_to_use()
if old_data is not False:
self.data = old_data
return
self.inititate_search_object()
logging.info('Total number of grid points is {}'.format(
len(self.input_data)))
data = []
for vals in tqdm(self.input_data):
FS = self.search.run_computefstatistic_single_point(*vals)
data.append(list(vals) + [FS])
data = np.array(data)
if return_data:
return data
else:
logging.info('Saving data to {}'.format(self.out_file))
np.savetxt(self.out_file, data, delimiter=' ')
self.data = data
def convert_F0_to_mismatch(self, F0, F0hat, Tseg):
DeltaF0 = F0[1] - F0[0]
m_spacing = (np.pi*Tseg*DeltaF0)**2 / 12.
N = len(F0)
return np.arange(-N*m_spacing/2., N*m_spacing/2., m_spacing)
def convert_F1_to_mismatch(self, F1, F1hat, Tseg):
DeltaF1 = F1[1] - F1[0]
m_spacing = (np.pi*Tseg**2*DeltaF1)**2 / 720.
N = len(F1)
return np.arange(-N*m_spacing/2., N*m_spacing/2., m_spacing)
def add_mismatch_to_ax(self, ax, x, y, xkey, ykey, xhat, yhat, Tseg):
axX = ax.twiny()
axX.zorder = -10
axY = ax.twinx()
axY.zorder = -10
if xkey == 'F0':
m = self.convert_F0_to_mismatch(x, xhat, Tseg)
axX.set_xlim(m[0], m[-1])
if ykey == 'F1':
m = self.convert_F1_to_mismatch(y, yhat, Tseg)
axY.set_ylim(m[0], m[-1])
def plot_1D(self, xkey):
fig, ax = plt.subplots()
xidx = self.keys.index(xkey)
x = np.unique(self.data[:, xidx])
z = self.data[:, -1]
plt.plot(x, z)
fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label))
def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None,
add_mismatch=None, xN=None, yN=None, flat_keys=[],
rel_flat_idxs=[], flatten_method=np.max,
predicted_twoF=None, cm=None, cbarkwargs={}):
""" Plots a 2D grid of 2F values
Parameters
----------
add_mismatch: tuple (xhat, yhat, Tseg)
If not None, add a secondary axis with the metric mismatch from the
point xhat, yhat with duration Tseg
flatten_method: np.max
Function to use in flattening the flat_keys
"""
if ax is None:
fig, ax = plt.subplots()
xidx = self.keys.index(xkey)
yidx = self.keys.index(ykey)
flat_idxs = [self.keys.index(k) for k in flat_keys]
x = np.unique(self.data[:, xidx])
y = np.unique(self.data[:, yidx])
flat_vals = [np.unique(self.data[:, j]) for j in flat_idxs]
z = self.data[:, -1]
Y, X = np.meshgrid(y, x)
shape = [len(x), len(y)] + [len(v) for v in flat_vals]
Z = z.reshape(shape)
if len(rel_flat_idxs) > 0:
Z = flatten_method(Z, axis=tuple(rel_flat_idxs))
if predicted_twoF:
Z = (predicted_twoF - Z) / (predicted_twoF + 4)
if cm is None:
cm = plt.cm.viridis_r
else:
if cm is None:
cm = plt.cm.viridis
pax = ax.pcolormesh(X, Y, Z, cmap=cm, vmin=vmin, vmax=vmax)
cb = plt.colorbar(pax, ax=ax, **cbarkwargs)
cb.set_label('$2\mathcal{F}$')
if add_mismatch:
self.add_mismatch_to_ax(ax, x, y, xkey, ykey, *add_mismatch)
ax.set_xlim(x[0], x[-1])
ax.set_ylim(y[0], y[-1])
labels = {'F0': '$f$', 'F1': '$\dot{f}$'}
ax.set_xlabel(labels[xkey])
ax.set_ylabel(labels[ykey])
if xN:
ax.xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(xN))
if yN:
ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(yN))
if save:
fig.tight_layout()
fig.savefig('{}/{}_2D.png'.format(self.outdir, self.label))
else:
return ax
def get_max_twoF(self):
twoF = self.data[:, -1]
idx = np.argmax(twoF)
v = self.data[idx, :]
d = OrderedDict(minStartTime=v[0], maxStartTime=v[1], F0=v[2], F1=v[3],
F2=v[4], Alpha=v[5], Delta=v[6], twoF=v[7])
return d
def print_max_twoF(self):
d = self.get_max_twoF()
print('Max twoF values for {}:'.format(self.label))
for k, v in d.iteritems():
print(' {}={}'.format(k, v))
class GridGlitchSearch(GridSearch):
""" Grid search using the SemiCoherentGlitchSearch """
@helper_functions.initializer
def __init__(self, label, outdir, sftfilepath=None, F0s=[0],
F1s=[0], F2s=[0], delta_F0s=[0], delta_F1s=[0], tglitchs=None,
Alphas=[0], Deltas=[0], tref=None, minStartTime=None,
maxStartTime=None, minCoverFreq=None, maxCoverFreq=None,
write_after=1000, earth_ephem=None, sun_ephem=None):
"""
Parameters
----------
label, outdir: str
A label and directory to read/write data from/to
sftfilepath: str
File patern to match SFTs
F0s, F1s, F2s, delta_F0s, delta_F1s, tglitchs, Alphas, Deltas: tuple
Length 3 tuple describing the grid for each parameter, e.g
[F0min, F0max, dF0], for a fixed value simply give [F0].
tref, minStartTime, maxStartTime: int
GPS seconds of the reference time, start time and end time
For all other parameters, see pyfstat.ComputeFStat.
"""
if tglitchs is None:
self.tglitchs = [self.maxStartTime]
if earth_ephem is None:
self.earth_ephem = self.earth_ephem_default
if sun_ephem is None:
self.sun_ephem = self.sun_ephem_default
self.search = SemiCoherentGlitchSearch(
label=label, outdir=outdir, sftfilepath=self.sftfilepath,
tref=tref, minStartTime=minStartTime, maxStartTime=maxStartTime,
minCoverFreq=minCoverFreq, maxCoverFreq=maxCoverFreq,
earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem,
BSGL=self.BSGL)
if os.path.isdir(outdir) is False:
os.mkdir(outdir)
self.out_file = '{}/{}_gridFS.txt'.format(self.outdir, self.label)
self.keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta', 'delta_F0',
'delta_F1', 'tglitch']
def get_input_data_array(self):
arrays = []
for tup in (self.F0s, self.F1s, self.F2s, self.Alphas, self.Deltas,
self.delta_F0s, self.delta_F1s, self.tglitchs):
arrays.append(self.get_array_from_tuple(tup))
input_data = []
for vals in itertools.product(*arrays):
input_data.append(vals)
self.arrays = arrays
self.input_data = np.array(input_data)
"""
Provides helpful functions to facilitate ease-of-use of pyfstat
"""
import os
import sys
import argparse
import logging
import inspect
from functools import wraps
import matplotlib.pyplot as plt
import numpy as np
def set_up_optional_tqdm():
try:
from tqdm import tqdm
except ImportError:
def tqdm(x, *args, **kwargs):
return x
return tqdm
def set_up_matplotlib_defaults():
plt.switch_backend('Agg')
plt.rcParams['text.usetex'] = True
plt.rcParams['axes.formatter.useoffset'] = False
def set_up_command_line_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("-q", "--quite", help="Decrease output verbosity",
action="store_true")
parser.add_argument("--no-interactive", help="Don't use interactive",
action="store_true")
parser.add_argument("-c", "--clean", help="Don't use cached data",
action="store_true")
parser.add_argument("-u", "--use-old-data", action="store_true")
parser.add_argument('-s', "--setup-only", action="store_true")
parser.add_argument('-n', "--no-template-counting", action="store_true")
parser.add_argument('unittest_args', nargs='*')
args, unknown = parser.parse_known_args()
sys.argv[1:] = args.unittest_args
if args.quite or args.no_interactive:
def tqdm(x, *args, **kwargs):
return x
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
stream_handler = logging.StreamHandler()
if args.quite:
stream_handler.setLevel(logging.WARNING)
else:
stream_handler.setLevel(logging.DEBUG)
stream_handler.setFormatter(logging.Formatter(
'%(asctime)s %(levelname)-8s: %(message)s', datefmt='%H:%M'))
logger.addHandler(stream_handler)
return args
def set_up_ephemeris_configuration():
config_file = os.path.expanduser('~')+'/.pyfstat.conf'
if os.path.isfile(config_file):
d = {}
with open(config_file, 'r') as f:
for line in f:
k, v = line.split('=')
k = k.replace(' ', '')
for item in [' ', "'", '"', '\n']:
v = v.replace(item, '')
d[k] = v
earth_ephem = d['earth_ephem']
sun_ephem = d['sun_ephem']
else:
logging.warning('No ~/.pyfstat.conf file found please provide the '
'paths when initialising searches')
earth_ephem = None
sun_ephem = None
return earth_ephem, sun_ephem
def round_to_n(x, n):
if not x:
return 0
power = -int(np.floor(np.log10(abs(x)))) + (n - 1)
factor = (10 ** power)
return round(x * factor) / factor
def texify_float(x, d=2):
if type(x) == str:
return x
x = round_to_n(x, d)
if 0.01 < abs(x) < 100:
return str(x)
else:
power = int(np.floor(np.log10(abs(x))))
stem = np.round(x / 10**power, d)
if d == 1:
stem = int(stem)
return r'${}{{\times}}10^{{{}}}$'.format(stem, power)
def initializer(func):
""" Decorator function to automatically assign the parameters to self """
names, varargs, keywords, defaults = inspect.getargspec(func)
@wraps(func)
def wrapper(self, *args, **kargs):
for name, arg in list(zip(names[1:], args)) + list(kargs.items()):
setattr(self, name, arg)
for name, default in zip(reversed(names), reversed(defaults)):
if not hasattr(self, name):
setattr(self, name, default)
func(self, *args, **kargs)
return wrapper
This diff is collapsed.
"""
Provides functions to aid in calculating the optimal setup based on the metric
volume estimates.
"""
import logging
import numpy as np
import scipy.optimize
import lal
import lalpulsar
def get_optimal_setup(
R, Nsegs0, tref, minStartTime, maxStartTime, DeltaOmega,
DeltaFs, fiducial_freq, detector_names, earth_ephem, sun_ephem):
logging.info('Calculating optimal setup for R={}, Nsegs0={}'.format(
R, Nsegs0))
V_0 = get_V_estimate(
Nsegs0, tref, minStartTime, maxStartTime, DeltaOmega, DeltaFs,
fiducial_freq, detector_names, earth_ephem, sun_ephem)
logging.info('Stage {}, nsegs={}, V={}'.format(0, Nsegs0, V_0))
nsegs_vals = [Nsegs0]
V_vals = [V_0]
i = 0
nsegs_i = Nsegs0
while nsegs_i > 1:
nsegs_i, V_i = get_nsegs_ip1(
nsegs_i, R, tref, minStartTime, maxStartTime, DeltaOmega,
DeltaFs, fiducial_freq, detector_names, earth_ephem, sun_ephem)
nsegs_vals.append(nsegs_i)
V_vals.append(V_i)
i += 1
logging.info(
'Stage {}, nsegs={}, V={}'.format(i, nsegs_i, V_i))
return nsegs_vals, V_vals
def get_nsegs_ip1(
nsegs_i, R, tref, minStartTime, maxStartTime, DeltaOmega,
DeltaFs, fiducial_freq, detector_names, earth_ephem, sun_ephem):
log10R = np.log10(R)
log10Vi = np.log10(get_V_estimate(
nsegs_i, tref, minStartTime, maxStartTime, DeltaOmega, DeltaFs,
fiducial_freq, detector_names, earth_ephem, sun_ephem))
def f(nsegs_ip1):
if nsegs_ip1[0] > nsegs_i:
return 1e6
if nsegs_ip1[0] < 0:
return 1e6
nsegs_ip1 = int(nsegs_ip1[0])
if nsegs_ip1 == 0:
nsegs_ip1 = 1
Vip1 = get_V_estimate(
nsegs_ip1, tref, minStartTime, maxStartTime, DeltaOmega,
DeltaFs, fiducial_freq, detector_names, earth_ephem, sun_ephem)
if Vip1[0] is None:
return 1e6
else:
log10Vip1 = np.log10(Vip1)
return np.abs(log10Vi[0] + log10R - log10Vip1[0])
res = scipy.optimize.minimize(f, .5*nsegs_i, method='Powell', tol=0.1,
options={'maxiter': 10})
nsegs_ip1 = int(res.x)
if nsegs_ip1 == 0:
nsegs_ip1 = 1
if res.success:
return nsegs_ip1, get_V_estimate(
nsegs_ip1, tref, minStartTime, maxStartTime, DeltaOmega, DeltaFs,
fiducial_freq, detector_names, earth_ephem, sun_ephem)
else:
raise ValueError('Optimisation unsuccesful')
def get_V_estimate(
nsegs, tref, minStartTime, maxStartTime, DeltaOmega, DeltaFs,
fiducial_freq, detector_names, earth_ephem, sun_ephem):
""" Returns V, Vsky, Vpe estimated from the super-sky metric
Parameters
----------
nsegs: int
Number of semi-coherent segments
tref: int
Reference time in GPS seconds
minStartTime, maxStartTime: int
Minimum and maximum SFT timestamps
DeltaOmega: float
Solid angle of the sky-patch
DeltaFs: array
Array of [DeltaF0, DeltaF1, ...], length determines the number of
spin-down terms.
fiducial_freq: float
Fidicual frequency
detector_names: array
Array of detectors to average over
earth_ephem, sun_ephem: st
Paths to the ephemeris files
"""
spindowns = len(DeltaFs) - 1
tboundaries = np.linspace(minStartTime, maxStartTime, nsegs+1)
ref_time = lal.LIGOTimeGPS(tref)
segments = lal.SegListCreate()
for j in range(len(tboundaries)-1):
seg = lal.SegCreate(lal.LIGOTimeGPS(tboundaries[j]),
lal.LIGOTimeGPS(tboundaries[j+1]),
j)
lal.SegListAppend(segments, seg)
detNames = lal.CreateStringVector(*detector_names)
detectors = lalpulsar.MultiLALDetector()
lalpulsar.ParseMultiLALDetector(detectors, detNames)
detector_weights = None
detector_motion = (lalpulsar.DETMOTION_SPIN
+ lalpulsar.DETMOTION_ORBIT)
ephemeris = lalpulsar.InitBarycenter(earth_ephem, sun_ephem)
try:
SSkyMetric = lalpulsar.ComputeSuperskyMetrics(
spindowns, ref_time, segments, fiducial_freq, detectors,
detector_weights, detector_motion, ephemeris)
except RuntimeError as e:
logging.debug('Encountered run-time error {}'.format(e))
return None, None, None
sqrtdetG_SKY = np.sqrt(np.linalg.det(
SSkyMetric.semi_rssky_metric.data[:2, :2]))
sqrtdetG_PE = np.sqrt(np.linalg.det(
SSkyMetric.semi_rssky_metric.data[2:, 2:]))
Vsky = .5*sqrtdetG_SKY*DeltaOmega
Vpe = sqrtdetG_PE * np.prod(DeltaFs)
if Vsky == 0:
Vsky = 1
if Vpe == 0:
Vpe = 1
return (Vsky * Vpe, Vsky, Vpe)
......@@ -3,8 +3,8 @@
from distutils.core import setup
setup(name='PyFstat',
version='0.1',
version='0.2',
author='Gregory Ashton',
author_email='gregory.ashton@ligo.org',
py_modules=['pyfstat'],
packages=['pyfstat'],
)
......@@ -208,7 +208,7 @@ class TestAuxillaryFunctions(Test):
def test_get_V_estimate_sky_F0_F1(self):
out =