helper_functions.py 8.68 KB
Newer Older
Gregory Ashton's avatar
Gregory Ashton committed
1
2
3
4
5
6
"""
Provides helpful functions to facilitate ease-of-use of pyfstat
"""

import os
import sys
7
import subprocess
Gregory Ashton's avatar
Gregory Ashton committed
8
9
10
import argparse
import logging
import inspect
11
import peakutils
Gregory Ashton's avatar
Gregory Ashton committed
12
from functools import wraps
13
from scipy.stats.distributions import ncx2
14
import lal
15
import lalpulsar
Gregory Ashton's avatar
Gregory Ashton committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39

import matplotlib.pyplot as plt
import numpy as np


def set_up_optional_tqdm():
    """Return the tqdm progress-bar callable, or a no-op stand-in.

    If the optional ``tqdm`` package cannot be imported, the returned
    callable simply hands back its first argument unchanged, ignoring any
    further positional or keyword arguments, so call sites can pass tqdm
    options unconditionally.
    """
    try:
        from tqdm import tqdm as progress
    except ImportError:
        progress = None

    if progress is not None:
        return progress

    def tqdm(x, *args, **kwargs):
        return x
    return tqdm


def set_up_matplotlib_defaults():
    """Configure matplotlib for non-interactive, LaTeX-rendered plotting.

    Switches to the 'Agg' backend, turns on LaTeX text rendering and turns
    off the axis-tick offset notation.
    """
    plt.switch_backend('Agg')
    for key, value in (('text.usetex', True),
                       ('axes.formatter.useoffset', False)):
        plt.rcParams[key] = value


def set_up_command_line_arguments():
    """ Parse command-line options and configure logging and progress bars.

    Known options are consumed and any remaining positional arguments
    (``unittest_args``) are written back to ``sys.argv[1:]`` so they can be
    passed on, e.g. to ``unittest``. A stream handler is added to the root
    logger, with its level chosen by the -q/--quiet and -v/--verbose flags.

    Returns
    -------
    args: argparse.Namespace
        The parsed arguments. The quiet flag is stored as ``args.quite``
        (historical spelling, kept for backwards compatibility).
    tqdm: callable
        A progress-bar wrapper, or a no-op pass-through when running
        quietly or non-interactively.
    """
    parser = argparse.ArgumentParser()
    # '--quite' is the historical (misspelt) flag; '--quiet' is accepted as
    # an alias and both store into args.quite.
    parser.add_argument("-q", "--quite", "--quiet", dest="quite",
                        help="Decrease output verbosity",
                        action="store_true")
    parser.add_argument("-v", "--verbose", help="Increase output verbosity",
                        action="store_true")
    parser.add_argument("--no-interactive", help="Don't use interactive",
                        action="store_true")
    parser.add_argument("-c", "--clean", help="Don't use cached data",
                        action="store_true")
    parser.add_argument("-u", "--use-old-data", action="store_true")
    parser.add_argument('-s', "--setup-only", action="store_true")
    parser.add_argument('-n', "--no-template-counting", action="store_true")
    parser.add_argument('unittest_args', nargs='*')
    args, _ = parser.parse_known_args()
    # Hand the leftover positional arguments back to sys.argv (e.g. for
    # unittest.main()).
    sys.argv[1:] = args.unittest_args
    if args.quite or args.no_interactive:
        def tqdm(x, *args, **kwargs):
            return x
    else:
        tqdm = set_up_optional_tqdm()

    logger = logging.getLogger()
    # The logger's own level filters records before any handler sees them,
    # so it must be lowered to DEBUG for --verbose to have an effect.
    if args.verbose and not args.quite:
        logger.setLevel(logging.DEBUG)
    else:
        logger.setLevel(logging.INFO)

    stream_handler = logging.StreamHandler()
    if args.quite:
        stream_handler.setLevel(logging.WARNING)
    elif args.verbose:
        stream_handler.setLevel(logging.DEBUG)
    else:
        stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(logging.Formatter(
        '%(asctime)s %(levelname)-8s: %(message)s', datefmt='%H:%M'))
    logger.addHandler(stream_handler)
    return args, tqdm
Gregory Ashton's avatar
Gregory Ashton committed
70
71
72


def set_up_ephemeris_configuration():
    """ Returns the earth_ephem and sun_ephem

    Reads ``~/.pyfstat.conf``, a file of ``key = value`` lines, and returns
    the values stored under the ``earth_ephem`` and ``sun_ephem`` keys.
    Blank lines and lines starting with ``#`` are ignored, and surrounding
    whitespace and quotes are stripped from keys and values. Whitespace
    inside a value, e.g. in a path, is preserved.

    Returns
    -------
    earth_ephem, sun_ephem: str or None
        Both None, with a warning logged, if the config file does not exist.

    Raises
    ------
    KeyError
        If the config file exists but lacks one of the two keys.
    ValueError
        If a non-blank, non-comment line contains no '='.
    """
    config_file = os.path.join(os.path.expanduser('~'), '.pyfstat.conf')
    if not os.path.isfile(config_file):
        logging.warning('No ~/.pyfstat.conf file found please provide the '
                        'paths when initialising searches')
        return None, None

    d = {}
    with open(config_file, 'r') as f:
        for line in f:
            line = line.strip()
            # Skip blank lines and comments rather than failing to unpack
            if not line or line.startswith('#'):
                continue
            # Split on the first '=' only, so values may themselves contain '='
            k, v = line.split('=', 1)
            d[k.strip()] = v.strip().strip('\'"').strip()
    return d['earth_ephem'], d['sun_ephem']


def round_to_n(x, n):
    """Round ``x`` to ``n`` significant figures; falsy ``x`` gives 0."""
    if x:
        # Scale so the n-th significant digit sits just left of the point
        scale = 10 ** (n - 1 - int(np.floor(np.log10(abs(x)))))
        return round(x * scale) / scale
    return 0


def texify_float(x, d=2):
    r""" Return a string representation of x suitable for plot labels

    Strings are returned unchanged and zero is returned as '0'. Otherwise x
    is rounded to d significant figures with round_to_n. Values with
    0.01 < |x| < 100 are returned via str(); anything else is returned in
    LaTeX scientific notation, e.g. texify_float(12345) gives
    '$1.2{\times}10^{4}$'.

    Parameters
    ----------
    x: float or str
    d: int
        Number of significant figures. If d == 1 the mantissa is an int.

    Returns
    -------
    str
    """
    if isinstance(x, str):
        return x
    if x == 0:
        # A string, for consistency with every other branch
        return '0'
    x = round_to_n(x, d)
    if 0.01 < abs(x) < 100:
        return str(x)
    # The exponent is taken from the rounded value so e.g. 9.99e5 -> 1.0e6
    power = int(np.floor(np.log10(abs(x))))
    stem = np.round(x / 10**power, d)
    if d == 1:
        stem = int(stem)
    return r'${}{{\times}}10^{{{}}}$'.format(stem, power)


def initializer(func):
    """ Decorator function to automatically assign the parameters to self

    Intended for ``__init__``-style methods. Every argument passed
    positionally or by keyword is stored as an attribute of the same name on
    ``self``. Any parameter with a default that was not supplied, and is not
    already an attribute of ``self``, is set to its default. The wrapped
    function is then called as normal and its return value passed through.
    """
    # inspect.getargspec was removed in Python 3.11; prefer getfullargspec
    # and fall back to getargspec only where it does not exist (Python 2).
    getspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
    spec = getspec(func)
    names = spec.args
    # spec.defaults is None when no parameter has a default value
    defaults = spec.defaults or ()

    @wraps(func)
    def wrapper(self, *args, **kargs):
        for name, arg in list(zip(names[1:], args)) + list(kargs.items()):
            setattr(self, name, arg)

        # Defaults align with the *last* len(defaults) parameter names
        for name, default in zip(reversed(names), reversed(defaults)):
            if not hasattr(self, name):
                setattr(self, name, default)

        return func(self, *args, **kargs)

    return wrapper

135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157

def get_peak_values(frequencies, twoF, threshold_2F, F0=None, F0range=None):
    """Locate peaks of twoF as a function of frequency.

    Parameters
    ----------
    frequencies, twoF: array
        Frequency grid and the corresponding 2F values.
    threshold_2F: float
        Detection threshold, passed to peakutils as a fraction of max(twoF).
    F0, F0range: float, optional
        If F0 is truthy, only points with |frequency - F0| < F0range are
        searched.

    Returns
    -------
    F0maxs, twoFmaxs, freq_errs: array
        Peak frequencies, peak 2F values and, for each peak, the grid spacing
        frequencies[1] - frequencies[0] as an error estimate.
    """
    if F0:
        keep = np.abs(frequencies - F0) < F0range
        frequencies, twoF = frequencies[keep], twoF[keep]
    # NOTE(review): peakutils.indexes rescales a relative `thres` by
    # (max - min) + min, so this equals threshold_2F only if min(twoF) == 0
    # -- confirm against the installed peakutils version.
    peak_idxs = peakutils.indexes(twoF, thres=1.*threshold_2F/np.max(twoF))
    spacing = frequencies[1] - frequencies[0]
    errors = spacing * np.ones(len(peak_idxs))
    return frequencies[peak_idxs], twoF[peak_idxs], errors


def get_comb_values(F0, frequencies, twoF, period, N=4):
    """Sample twoF on a comb of offsets n/period around F0, n in [-N, N].

    Parameters
    ----------
    F0: float
        Centre frequency of the comb.
    frequencies, twoF: array
        Frequency grid and the corresponding 2F values.
    period: float or str
        Period in seconds, or 'sidereal' / 'terrestrial' for the built-in
        day lengths used below.
    N: int
        Number of comb teeth on each side of F0.

    Returns
    -------
    comb_frequencies: list
        The offsets n/period relative to F0 (not absolute frequencies).
    twoF values: array
        twoF at the grid point nearest each F0 + offset.
    freq_errs: array
        The grid spacing repeated for each tooth.
    """
    day_lengths = (('sidereal', 23*60*60 + 56*60 + 4.0616),
                   ('terrestrial', 86400))
    for name, seconds in day_lengths:
        if period == name:
            period = seconds
            break
    spacing = frequencies[1] - frequencies[0]
    offsets = [n / period for n in range(-N, N + 1)]
    nearest = [np.argmin(np.abs(frequencies - F0 - offset))
               for offset in offsets]
    errors = spacing * np.ones(len(nearest))
    return offsets, twoF[nearest], errors

158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188

def compute_P_twoFstarcheck(twoFstarcheck, twoFcheck, M0, plot=False):
    """ Returns the unnormalised pdf of twoFstarcheck given twoFcheck

    Integrates the following product over the non-centrality parameter
    rho2starcheck with the trapezoidal rule on 500 grid points:

    - ncx2.pdf(twoFstarcheck, df=4*M0, nc=rho2starcheck)
    - ncx2.pdf(twoFcheck, df=4, nc=rho2starcheck)

    Parameters
    ----------
    twoFstarcheck, twoFcheck: float
    M0: int or float
        Multiplier on the degrees of freedom of the twoFstarcheck density.
    plot: bool
        If True, plot the integrand and save it with fig.savefig('test').

    Returns
    -------
    float
    """
    # Integration grid; the upper limit grows with the observed values
    upper = 4+twoFstarcheck + 0.5*(2*(4*M0+2*twoFcheck))
    rho2starcheck = np.linspace(1e-1, upper, 500)
    integrand = (ncx2.pdf(twoFstarcheck, 4*M0, rho2starcheck)
                 * ncx2.pdf(twoFcheck, 4, rho2starcheck))
    if plot:
        fig, ax = plt.subplots()
        ax.plot(rho2starcheck, integrand)
        fig.savefig('test')
        # Release the figure so repeated calls do not accumulate open figures
        plt.close(fig)
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    return trapezoid(integrand, rho2starcheck)


def compute_pstar(twoFcheck_obs, twoFstarcheck_obs, m0, plot=False):
    """ Returns 2*min(pstar_l, 1 - pstar_l) for the observed values

    The procedure is:

    1. Evaluate compute_P_twoFstarcheck on a 500-point grid of twoFstarcheck
       values, with M0 = 2*m0 + 1.
    2. Normalise the result by its trapezoidal integral C.
    3. Integrate the normalised curve up to the grid point nearest
       twoFstarcheck_obs to obtain pstar_l.

    Parameters
    ----------
    twoFcheck_obs, twoFstarcheck_obs: float
    m0: int
    plot: bool
        If True, plot the curve with the integrated region shaded and save
        it with fig.savefig('test').

    Returns
    -------
    float
    """
    M0 = 2*m0 + 1
    upper = 4+twoFcheck_obs + (2*(4*M0+2*twoFcheck_obs))
    twoFstarcheck_vals = np.linspace(1e-1, upper, 500)
    P_twoFstarcheck = np.array(
        [compute_P_twoFstarcheck(twoFstarcheck, twoFcheck_obs, M0)
         for twoFstarcheck in twoFstarcheck_vals])
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    C = trapezoid(P_twoFstarcheck, twoFstarcheck_vals)
    idx = np.argmin(np.abs(twoFstarcheck_vals - twoFstarcheck_obs))
    if plot:
        fig, ax = plt.subplots()
        ax.plot(twoFstarcheck_vals, P_twoFstarcheck)
        ax.fill_between(twoFstarcheck_vals[:idx+1], 0, P_twoFstarcheck[:idx+1])
        ax.axvline(twoFstarcheck_vals[idx])
        fig.savefig('test')
        # Release the figure so repeated calls do not accumulate open figures
        plt.close(fig)
    pstar_l = trapezoid(P_twoFstarcheck[:idx+1]/C, twoFstarcheck_vals[:idx+1])
    return 2*np.min([pstar_l, 1-pstar_l])
189
190


191
192
def run_commandline(cl, log_level=20):
    """Run a string cmd as a subprocess, check for errors and return output.

    Parameters
    ----------
    cl: str
        Command to run. It is executed through the shell (``shell=True``),
        so it must only be built from trusted input.
    log_level: int
        See https://docs.python.org/2/library/logging.html#logging-levels,
        default is '20' (INFO)

    Returns
    -------
    out: str
        The command's stdout with stderr merged in.

    Raises
    ------
    subprocess.CalledProcessError
        If the command exits with a non-zero status; its output is logged at
        ERROR level before re-raising.
    """
    logging.log(log_level, 'Now executing: ' + cl)
    try:
        out = subprocess.check_output(
            cl,
            stderr=subprocess.STDOUT,  # merge stderr into the returned output
            shell=True,  # run via the shell so cl may use shell syntax/env
            universal_newlines=True,  # decode output to str with linebreaks
        )
    except subprocess.CalledProcessError as e:
        logging.error('Execution failed:')
        logging.error(e.output)
        raise
    return out
218

219

220
def convert_array_to_gsl_matrix(array):
    """Build a lal.gsl_matrix with the shape of ``array`` holding its data.

    A new lal.gsl_matrix is allocated from ``array.shape`` (presumably a 2D
    numpy array -- confirm at call sites) and ``array`` is assigned to its
    ``.data`` attribute.
    """
    shape = array.shape
    matrix = lal.gsl_matrix(*shape)
    matrix.data = array
    return matrix
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241


def get_sft_array(sftfilepattern, data_duration, F0, dF0):
    """ Return the raw data from a set of sfts

    Parameters
    ----------
    sftfilepattern: str
        Pattern passed to lalpulsar.SFTdataFind to locate the SFT files.
    data_duration: float
        End value of the returned time axis.
    F0, dF0: float
        The band F0-dF0 to F0+dF0 is requested from lalpulsar.LoadMultiSFTs.

    Returns
    -------
    times: array, shape (nsfts,)
        nsfts values evenly spaced from 0 to data_duration.
    freqs: array, shape (n,)
        Frequency of each bin, f0 + k*deltaF for k = 0..n-1.
    data: array, shape (n, nsfts)
        Absolute value of the SFT data, one column per SFT.

    Raises
    ------
    ValueError
        If no SFTs were loaded.
    """
    SFTCatalog = lalpulsar.SFTdataFind(
        sftfilepattern, lalpulsar.SFTConstraints())
    MultiSFTs = lalpulsar.LoadMultiSFTs(SFTCatalog, F0-dF0, F0+dF0)
    # NOTE: only the first detector's SFTs (index 0) are used
    SFTs = list(MultiSFTs.data[0].data)
    if not SFTs:
        raise ValueError(
            'No SFTs loaded for pattern {}'.format(sftfilepattern))
    data = np.array([np.abs(sft.data.data) for sft in SFTs]).T
    n, nsfts = data.shape
    # Bin k is at f0 + k*deltaF; np.linspace(f0, f0 + n*deltaF, n) would
    # stretch the spacing to n*deltaF/(n-1).
    last = SFTs[-1]
    freqs = last.f0 + np.arange(n) * last.deltaF
    times = np.linspace(0, data_duration, nsfts)

    return times, freqs, data
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264


def get_covering_band(tref, tstart, tend, uniform_prior):
    """ Return lalpulsar.CWSignalCoveringBand for a spin range over [tstart, tend]

    Parameters
    ----------
    tref, tstart, tend: float or int
        GPS times, converted with lal.LIGOTimeGPS. tref is the reference
        time of the spin parameters.
    uniform_prior: dict
        Must contain 'F0', 'F1' and 'F2'. Each value is either a constant,
        or a dict {'type': 'unif', 'lower': l, 'upper': u}, which is turned
        into a centre value (l+u)/2 and a band u-l.

    Returns
    -------
    The result of lalpulsar.CWSignalCoveringBand; presumably the
    (min, max) covering frequencies -- confirm against the lalpulsar docs.

    Raises
    ------
    ValueError
        If a key is missing, or a dict prior is not of type 'unif'.
    """
    tref = lal.LIGOTimeGPS(tref)
    tstart = lal.LIGOTimeGPS(tstart)
    tend = lal.LIGOTimeGPS(tend)
    psr = lalpulsar.PulsarSpinRange()
    for i, key in enumerate(['F0', 'F1', 'F2']):
        if key not in uniform_prior:
            raise ValueError(
                'uniform_prior should contain unif or const values of F0, F1, F2')
        prior = uniform_prior[key]
        if isinstance(prior, dict):
            # Only uniform priors can be mapped onto a centre value and band
            if prior.get('type') != 'unif':
                raise ValueError(
                    'uniform_prior[{}] is a dict but not of type unif'.format(
                        key))
            lower, upper = prior['lower'], prior['upper']
            psr.fkdot[i] = (lower + upper) / 2.
            psr.fkdotBand[i] = upper - lower
        else:
            psr.fkdot[i] = prior
            psr.fkdotBand[i] = 0
    psr.refTime = tref
    # NOTE(review): the trailing zeros are presumably the binary-orbit limits
    # (max asini, min period, max eccentricity), i.e. an isolated source --
    # confirm.
    return lalpulsar.CWSignalCoveringBand(tstart, tend, psr, 0, 0, 0)