pyfstat.py 77 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
""" Classes for various types of searches using ComputeFstatistic """
import os
import sys
import itertools
import logging
import argparse
import copy
import glob
import inspect
from functools import wraps
11
import subprocess
12
from collections import OrderedDict
13
14
15

import numpy as np
import matplotlib
16
matplotlib.use('Agg')
17
18
19
20
import matplotlib.pyplot as plt
import emcee
import corner
import dill as pickle
21
import lal
22
23
import lalpulsar

24
25
26
27
28
29
try:
    from tqdm import tqdm
except ImportError:
    def tqdm(iterable, *args, **kwargs):
        """ Fallback used when tqdm is not installed

        Returns ``iterable`` unchanged. Extra positional/keyword arguments
        (e.g. ``total``) are accepted and ignored so that call sites written
        for the real tqdm keep working without it.
        """
        return iterable

# Global plot styling: render text with LaTeX and never use an offset on
# axis tick labels
plt.rcParams.update({
    'text.usetex': True,
    'axes.formatter.useoffset': False,
})

def read_config_file(filename):
    """ Parse a simple ``key = value`` config file into a dict of strings

    The following lines are skipped:

    - blank lines;
    - lines starting with '#';
    - lines containing no '='.

    Only the first '=' on a line separates the key from the value, so values
    may themselves contain '='. Spaces are removed from keys. Spaces and
    single/double quote characters are removed from values.
    """
    config = {}
    with open(filename, 'r') as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#') or '=' not in line:
                continue
            key, val = line.split('=', 1)
            config[key.replace(' ', '')] = (
                val.replace(' ', '').replace("'", '').replace('"', ''))
    return config


config_file = os.path.join(os.path.expanduser('~'), '.pyfstat.conf')
if os.path.isfile(config_file):
    d = read_config_file(config_file)
    # Missing entries fall back to None (callers must then pass the paths
    # explicitly) instead of raising KeyError at import time
    earth_ephem = d.get('earth_ephem')
    sun_ephem = d.get('sun_ephem')
    if earth_ephem is None or sun_ephem is None:
        logging.warning('{} does not define both earth_ephem and sun_ephem, '
                        'please provide the paths when initialising searches'
                        .format(config_file))
else:
    logging.warning('No ~/.pyfstat.conf file found please provide the paths '
                    'when initialising searches')
    earth_ephem = None
    sun_ephem = None

parser = argparse.ArgumentParser()
# "--quite" is the historical (misspelt) flag and is kept for backwards
# compatibility; "--quiet" is accepted as an alias. Both set ``args.quite``.
parser.add_argument("-q", "--quite", "--quiet", dest="quite",
                    help="Decrease output verbosity", action="store_true")
parser.add_argument("-c", "--clean", help="Don't use cached data",
                    action="store_true")
parser.add_argument("-u", "--use-old-data", action="store_true")
parser.add_argument('unittest_args', nargs='*')
# Unknown options are tolerated; the remaining positional arguments are put
# back into sys.argv (presumably for a unittest runner -- confirm)
args, unknown = parser.parse_known_args()
sys.argv[1:] = args.unittest_args

# The root logger passes everything through; verbosity on the console is
# controlled solely by the stream handler's level (-q/--quite => WARNING).
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
console_level = logging.WARNING if args.quite else logging.DEBUG
console_format = logging.Formatter(
    '%(asctime)s %(levelname)-8s: %(message)s', datefmt='%H:%M')
stream_handler = logging.StreamHandler()
stream_handler.setLevel(console_level)
stream_handler.setFormatter(console_format)
logger.addHandler(stream_handler)


def initializer(func):
    """ Automatically assigns the parameters to self

    Decorator for ``__init__`` methods. It works in three steps:

    1. Every argument passed by position or keyword is stored as an
       attribute of the same name on ``self``.
    2. Each parameter that has a default, and that ``self`` does not already
       have as an attribute (class attributes included), is assigned its
       default value.
    3. The wrapped function is called.
    """
    # inspect.getargspec was removed in Python 3.11; prefer getfullargspec
    # where available and fall back to getargspec on Python 2
    getspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
    spec = getspec(func)
    names = spec.args
    # spec.defaults is None when no parameter has a default value
    defaults = spec.defaults or ()

    @wraps(func)
    def wrapper(self, *args, **kargs):
        for name, arg in list(zip(names[1:], args)) + list(kargs.items()):
            setattr(self, name, arg)

        # Defaults belong to the trailing parameters, hence the reversal
        for name, default in zip(reversed(names), reversed(defaults)):
            if not hasattr(self, name):
                setattr(self, name, default)

        return func(self, *args, **kargs)

    return wrapper

def read_par(label, outdir):
    """ Read in a .par file, returns a dictionary of the values

    Parameters
    ----------
    label, outdir: str
        The file read is ``{outdir}/{label}.par``.

    Returns
    -------
    d: dict
        Maps each key to its value as a np.float64. Lines of the form
        ``key = value`` (spaces around '=' optional, trailing ';' ignored)
        are parsed. Blank lines, lines starting with '#' and lines without
        '=' are skipped.
    """
    filename = '{}/{}.par'.format(outdir, label)
    d = {}
    with open(filename, 'r') as f:
        for line in f:
            line = line.strip()
            if '=' not in line or line.startswith('#'):
                continue
            # Split on the first '=' only so "key=val" and "key = val" both
            # work
            key, val = line.split('=', 1)
            # NOTE(review): eval executes arbitrary code taken from the file;
            # only read .par files from trusted sources. eval is kept rather
            # than float() in case existing files hold expressions.
            d[key.strip()] = np.float64(eval(val.strip().rstrip('; ')))
    return d


class BaseSearchClass(object):
    """ The base search class, provides ephemeris and general utilities """

    earth_ephem_default = earth_ephem
    sun_ephem_default = sun_ephem

    def add_log_file(self):
        """ Log output to a log-file, requires class to have outdir and label

        Messages at INFO level and above are written to
        ``{outdir}/{label}.log`` via a handler on the root logger. If the
        root logger already has a FileHandler for that file, no second
        handler is added; otherwise every message would be written twice.
        """
        logfilename = '{}/{}.log'.format(self.outdir, self.label)
        root_logger = logging.getLogger()
        abs_logfilename = os.path.abspath(logfilename)
        for handler in root_logger.handlers:
            if (isinstance(handler, logging.FileHandler) and
                    handler.baseFilename == abs_logfilename):
                return
        fh = logging.FileHandler(logfilename)
        fh.setLevel(logging.INFO)
        fh.setFormatter(logging.Formatter(
            '%(asctime)s %(levelname)-8s: %(message)s',
            datefmt='%y-%m-%d %H:%M'))
        root_logger.addHandler(fh)

    def shift_matrix(self, n, dT):
        """ Generate the shift matrix

        Returns the (n, n) upper-triangular matrix with ones on the diagonal
        and, for j > i, entry (i, j) equal to dT**(j-i) / (j-i)!. The
        off-diagonal entries of the first row (the phase row) are
        additionally multiplied by 2*pi, presumably converting cycles to
        radians.
        """
        # math.factorial is used directly: the np.math alias was deprecated
        # in NumPy 1.25 and removed in NumPy 2.0
        from math import factorial
        m = np.zeros((n, n))
        for i in range(n):
            m[i, i] = 1.0
            prefactor = 2*np.pi if i == 0 else 1.0
            for j in range(i + 1, n):
                m[i, j] = prefactor*float(dT)**(j-i) / factorial(j-i)
        return m

    def shift_coefficients(self, theta, dT):
        """ Shift a set of coefficients by dT

        Parameters
        ----------
        theta: array-like, shape (n,)
            vector of the expansion coefficients to transform starting from the
            lowest degree e.g [phi, F0, F1,...].
        dT: float
            difference between the two reference times as tref_new - tref_old.

        Returns
        -------
        theta_new: array-like shape (n,)
            vector of the coefficients as evaluated at the new reference time.
        """
        n = len(theta)
        m = self.shift_matrix(n, dT)
        return np.dot(m, theta)

    def calculate_thetas(self, theta, delta_thetas, tbounds, theta0_idx=0):
        """ Calculates the set of coefficients for the post-glitch signal

        Parameters
        ----------
        theta: array-like, shape (n,)
            Coefficients [phi, F0, F1, ...], at self.tref, of segment
            ``theta0_idx``.
        delta_thetas: array-like, shape (nglitch, n)
            Jump in the coefficients at each glitch: post-glitch minus
            pre-glitch values, both evaluated at the glitch time.
        tbounds: array-like, shape (nglitch+2,)
            Segment boundaries [tstart, tglitch_0, ..., tend]; glitch i
            occurs at tbounds[i+1].
        theta0_idx: int
            Index of the segment to which ``theta`` refers.

        Returns
        -------
        thetas: list
            The coefficients, at self.tref, of each segment in time order.
        """
        thetas = [theta]

        # Segments before theta0_idx: walk backwards from glitch
        # theta0_idx-1 down to glitch 0, each time undoing the glitch from
        # the segment immediately after it (the current thetas[0]).
        n_before = min(theta0_idx, len(delta_thetas))
        for i in reversed(range(n_before)):
            tglitch = tbounds[i+1]
            post_theta_at_glitch = self.shift_coefficients(
                thetas[0], tglitch - self.tref)
            pre_theta_at_glitch = post_theta_at_glitch - delta_thetas[i]
            thetas.insert(0, self.shift_coefficients(
                pre_theta_at_glitch, self.tref - tglitch))

        # Segments after theta0_idx: walk forwards, applying each glitch to
        # the segment immediately before it (the current thetas[-1]).
        for i in range(n_before, len(delta_thetas)):
            tglitch = tbounds[i+1]
            pre_theta_at_glitch = self.shift_coefficients(
                thetas[-1], tglitch - self.tref)
            post_theta_at_glitch = pre_theta_at_glitch + delta_thetas[i]
            thetas.append(self.shift_coefficients(
                post_theta_at_glitch, self.tref - tglitch))

        return thetas


class ComputeFstat(object):
    """ Base class providing interface to lalpulsar.ComputeFstat """

    earth_ephem_default = earth_ephem
    sun_ephem_default = sun_ephem

    @initializer
    def __init__(self, tref, sftfilepath=None,
                 minStartTime=None, maxStartTime=None,
                 minCoverFreq=None, maxCoverFreq=None,
                 detector=None, earth_ephem=None, sun_ephem=None,
                 binary=False, transient=True, BSGL=False):
        """
        Parameters
        ----------
        tref: int
            GPS seconds of the reference time.
        sftfilepath: str
            File pattern to match SFTs
        minStartTime, maxStartTime: int
            GPS limits on the SFT start times, passed as SFTConstraints
            (ignored if None or 0).
        minCoverFreq, maxCoverFreq: float
            The min and max cover frequency passed to CreateFstatInput. If
            either is None, the frequency range of the first SFT less 0.5Hz
            at each end is used.
        detector: str
            Two character reference to the data to use, specify None for no
            constraint.
        earth_ephem, sun_ephem: str
            Paths of the two files containing positions of Earth and Sun,
            respectively at evenly spaced times, as passed to CreateFstatInput.
            If None, the class defaults (read from ~/.pyfstat.conf) are used.
        binary: bool
            If true, the binary orbital parameters passed to
            run_computefstatistic_single_point are used.
        transient: bool
            If true, allow for the Fstat to be computed over a transient range.
        BSGL: bool
            If true, compute the BSGL rather than the twoF value.

        """

        if earth_ephem is None:
            self.earth_ephem = self.earth_ephem_default
        if sun_ephem is None:
            self.sun_ephem = self.sun_ephem_default

        self.init_computefstatistic_single_point()

    def init_computefstatistic_single_point(self):
        """ Initialisation step of run_computefstatistic for a single point

        This method does the following, in order:

        - Loads the SFTs.
        - Initialises the ephemerides.
        - Creates self.FstatInput, self.PulsarDopplerParams and
          self.FstatResults.
        - Sets self.whatToCompute.
        - If BSGL is set, creates self.BSGLSetup and self.twoFX.
        - If transient is set, creates self.windowRange.
        """

        logging.info('Initialising SFTCatalog')
        constraints = lalpulsar.SFTConstraints()
        if self.detector:
            constraints.detector = self.detector
        if self.minStartTime:
            constraints.minStartTime = lal.LIGOTimeGPS(self.minStartTime)
        if self.maxStartTime:
            constraints.maxStartTime = lal.LIGOTimeGPS(self.maxStartTime)

        logging.info('Loading data matching pattern {}'.format(
                     self.sftfilepath))
        SFTCatalog = lalpulsar.SFTdataFind(self.sftfilepath, constraints)
        names = list(set([d.header.name for d in SFTCatalog.data]))
        epochs = [d.header.epoch for d in SFTCatalog.data]
        logging.info(
            'Loaded {} data files from detectors {} spanning {} to {}'.format(
                len(epochs), names, int(epochs[0]), int(epochs[-1])))

        logging.info('Initialising ephems')
        ephems = lalpulsar.InitBarycenter(self.earth_ephem, self.sun_ephem)

        logging.info('Initialising FstatInput')
        dFreq = 0
        # Transient searches need the per-detector F-stat atoms so that the
        # F-stat can be recomputed over a sub-range of the data
        if self.transient:
            self.whatToCompute = lalpulsar.FSTATQ_ATOMS_PER_DET
        else:
            self.whatToCompute = lalpulsar.FSTATQ_2F

        FstatOptionalArgs = lalpulsar.FstatOptionalArgsDefaults

        if self.minCoverFreq is None or self.maxCoverFreq is None:
            fA = SFTCatalog.data[0].header.f0
            numBins = SFTCatalog.data[0].numBins
            fB = fA + (numBins-1)*SFTCatalog.data[0].header.deltaF
            self.minCoverFreq = fA + 0.5
            self.maxCoverFreq = fB - 0.5
            logging.info('Min/max cover freqs not provided, using '
                         '{} and {}, est. from SFTs'.format(
                             self.minCoverFreq, self.maxCoverFreq))

        self.FstatInput = lalpulsar.CreateFstatInput(SFTCatalog,
                                                     self.minCoverFreq,
                                                     self.maxCoverFreq,
                                                     dFreq,
                                                     ephems,
                                                     FstatOptionalArgs
                                                     )

        logging.info('Initialising PulsarDoplerParams')
        PulsarDopplerParams = lalpulsar.PulsarDopplerParams()
        PulsarDopplerParams.refTime = self.tref
        PulsarDopplerParams.Alpha = 1
        PulsarDopplerParams.Delta = 1
        PulsarDopplerParams.fkdot = np.array([0, 0, 0, 0, 0, 0, 0])
        self.PulsarDopplerParams = PulsarDopplerParams

        logging.info('Initialising FstatResults')
        self.FstatResults = lalpulsar.FstatResults()

        if self.BSGL:
            logging.info('Initialising BSGL: this will fail if numDet < 2')
            # Tuning parameters - to be reviewed
            numDetectors = 2
            Fstar0sc = 15.
            oLGX = np.zeros(10)
            oLGX[:numDetectors] = 1./numDetectors
            self.BSGLSetup = lalpulsar.CreateBSGLSetup(numDetectors,
                                                       Fstar0sc,
                                                       oLGX,
                                                       True,
                                                       1)
            self.twoFX = np.zeros(10)
            self.whatToCompute = (self.whatToCompute +
                                  lalpulsar.FSTATQ_2F_PER_DET)

        if self.transient:
            logging.info('Initialising transient parameters')
            self.windowRange = lalpulsar.transientWindowRange_t()
            self.windowRange.type = lalpulsar.TRANSIENT_RECTANGULAR
            self.windowRange.t0Band = 0
            self.windowRange.dt0 = 1
            self.windowRange.tauBand = 0
            self.windowRange.dtau = 1

    def run_computefstatistic_single_point(self, tstart, tend, F0, F1,
                                           F2, Alpha, Delta, asini=None,
                                           period=None, ecc=None, tp=None,
                                           argp=None):
        """ Returns the twoF fully-coherently at a single point

        The window [tstart, tend] is only used when ``transient`` is set, and
        the binary parameters only when ``binary`` is set. If ``BSGL`` is
        set, the value returned is BSGL_PREFACTOR times the output of
        lalpulsar.ComputeBSGL instead of twoF.
        """

        # Scale applied to lalpulsar.ComputeBSGL output (= 10*ln(10))
        BSGL_PREFACTOR = 10 * 1 / np.log10(np.exp(1))

        self.PulsarDopplerParams.fkdot = np.array([F0, F1, F2, 0, 0, 0, 0])
        self.PulsarDopplerParams.Alpha = Alpha
        self.PulsarDopplerParams.Delta = Delta
        if self.binary:
            self.PulsarDopplerParams.asini = asini
            self.PulsarDopplerParams.period = period
            self.PulsarDopplerParams.ecc = ecc
            self.PulsarDopplerParams.tp = tp
            self.PulsarDopplerParams.argp = argp

        lalpulsar.ComputeFstat(self.FstatResults,
                               self.FstatInput,
                               self.PulsarDopplerParams,
                               1,
                               self.whatToCompute
                               )

        # Truthiness (not `is False`) so that this agrees with the checks
        # used in init_computefstatistic_single_point for e.g. 0/1 flags
        if not self.transient:
            if not self.BSGL:
                return self.FstatResults.twoF[0]

            twoF = float(self.FstatResults.twoF[0])
            self.twoFX[0] = self.FstatResults.twoFPerDet(0)
            self.twoFX[1] = self.FstatResults.twoFPerDet(1)
            BSGL = lalpulsar.ComputeBSGL(twoF, self.twoFX,
                                         self.BSGLSetup)
            return BSGL_PREFACTOR * BSGL

        self.windowRange.t0 = int(tstart)  # TYPE UINT4
        self.windowRange.tau = int(tend - tstart)  # TYPE UINT4

        FS = lalpulsar.ComputeTransientFstatMap(
            self.FstatResults.multiFatoms[0], self.windowRange, False)

        if not self.BSGL:
            return 2*FS.F_mn.data[0][0]

        # NOTE(review): copy.copy is shallow and `.lenth`/`.data` do not
        # appear to be FstatResults fields, so FstatResults_single.multiFatoms
        # looks like the same multi-detector atoms as above. FS0 and FS1 may
        # therefore equal FS rather than being per-detector -- verify.
        FstatResults_single = copy.copy(self.FstatResults)
        FstatResults_single.lenth = 1
        FstatResults_single.data = self.FstatResults.multiFatoms[0].data[0]
        FS0 = lalpulsar.ComputeTransientFstatMap(
            FstatResults_single.multiFatoms[0], self.windowRange, False)
        FstatResults_single.data = self.FstatResults.multiFatoms[0].data[1]
        FS1 = lalpulsar.ComputeTransientFstatMap(
            FstatResults_single.multiFatoms[0], self.windowRange, False)

        self.twoFX[0] = 2*FS0.F_mn.data[0][0]
        self.twoFX[1] = 2*FS1.F_mn.data[0][0]
        BSGL = lalpulsar.ComputeBSGL(2*FS.F_mn.data[0][0], self.twoFX,
                                     self.BSGLSetup)

        return BSGL_PREFACTOR * BSGL


class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat):
    """ A semi-coherent glitch search

    This implements a basic semi-coherent glitch F-stat in which the data
    is divided into segments either side of the proposed glitches and the
    fully-coherent F-stat in each segment is summed to give the
    semi-coherent F-stat
    """

    @initializer
    def __init__(self, label, outdir, tref, tstart, tend, nglitch=0,
                 sftfilepath=None, theta0_idx=0, BSGL=False,
                 minCoverFreq=None, maxCoverFreq=None, minStartTime=None,
                 maxStartTime=None, detector=None, earth_ephem=None,
                 sun_ephem=None):
        """
        Parameters
        ----------
        label, outdir: str
            A label and directory to read/write data from/to.
        tref, tstart, tend: int
            GPS seconds of the reference time, and start and end of the data.
        nglitch: int
            The (fixed) number of glitches; this can be zero, but
            occasionally this causes issues (in which case just use
            ComputeFstat).
        sftfilepath: str
            File pattern to match SFTs
        theta0_idx, int
            Index (zero-based) of which segment the theta refers to - useful
            if providing a tight prior on theta to allow the signal to jump
            to theta (and not just from)
        BSGL: bool
            If true, compute the BSGL rather than the twoF value in each
            segment.
        minCoverFreq, maxCoverFreq: float
            The min and max cover frequency passed to CreateFstatInput. If
            either is None, the frequency range of the first SFT less 0.5Hz
            at each end is used.
        minStartTime, maxStartTime: int
            GPS limits on the SFT start times, passed as SFTConstraints
            (ignored if None or 0).
        detector: str
            Two character reference to the data to use, specify None for no
            constraint.
        earth_ephem, sun_ephem: str
            Paths of the two files containing positions of Earth and Sun,
            respectively at evenly spaced times, as passed to CreateFstatInput.
            If None, the class defaults (read from ~/.pyfstat.conf) are used.
        """

        self.fs_file_name = "{}/{}_FS.dat".format(self.outdir, self.label)
        if self.earth_ephem is None:
            self.earth_ephem = self.earth_ephem_default
        if self.sun_ephem is None:
            self.sun_ephem = self.sun_ephem_default
        # Each segment is evaluated with a transient window, and this search
        # does not support binary parameters
        self.transient = True
        self.binary = False
        self.init_computefstatistic_single_point()

    def compute_nglitch_fstat(self, F0, F1, F2, Alpha, Delta, *args):
        """ Returns the semi-coherent glitch summed twoF

        The last 3*nglitch entries of ``args`` are, in order:

        - the nglitch delta_F0 values;
        - the nglitch delta_F1 values;
        - the nglitch glitch times.

        Returns -np.inf if the sum is not finite.
        """
        args = list(args)
        n = self.nglitch
        # Guard n == 0 explicitly: args[-0:] would be the *whole* list and
        # any extra arguments would be mistaken for glitch times
        if n > 0:
            delta_F0s = args[-3*n:-2*n]
            delta_F1s = args[-2*n:-n]
            tglitches = args[-n:]
        else:
            delta_F0s, delta_F1s, tglitches = [], [], []
        tboundaries = [self.tstart] + tglitches + [self.tend]

        delta_F2 = np.zeros(len(delta_F0s))
        delta_phi = np.zeros(len(delta_F0s))
        theta = [0, F0, F1, F2]
        delta_thetas = np.atleast_2d(
                np.array([delta_phi, delta_F0s, delta_F1s, delta_F2]).T)

        thetas = self.calculate_thetas(theta, delta_thetas, tboundaries,
                                       theta0_idx=self.theta0_idx)

        twoFSum = 0
        for i, theta_i_at_tref in enumerate(thetas):
            ts, te = tboundaries[i], tboundaries[i+1]
            twoFVal = self.run_computefstatistic_single_point(
                ts, te, theta_i_at_tref[1], theta_i_at_tref[2],
                theta_i_at_tref[3], Alpha, Delta)
            twoFSum += twoFVal

        if np.isfinite(twoFSum):
            return twoFSum
        else:
            return -np.inf

    def compute_glitch_fstat_single(self, F0, F1, F2, Alpha, Delta, delta_F0,
                                    delta_F1, tglitch):
        """ Returns the semi-coherent glitch summed twoF for nglitch=1

        Note: used for testing
        """
        twoFsegA = self.run_computefstatistic_single_point(
            self.tstart, tglitch, F0, F1, F2, Alpha, Delta)

        if tglitch == self.tend:
            return twoFsegA

        # shift_coefficients expects [phi, F0, F1, ...] (its first row
        # carries the 2*pi phase factor), so a zero phase is prepended as in
        # compute_nglitch_fstat; passing [F0, F1, F2] would scale the F0
        # shift by 2*pi.
        theta = [0, F0, F1, F2]
        delta_theta = [0, delta_F0, delta_F1, 0]
        tref = self.tref

        theta_at_glitch = self.shift_coefficients(theta, tglitch - tref)
        theta_post_glitch_at_glitch = theta_at_glitch + delta_theta
        theta_post_glitch = self.shift_coefficients(
            theta_post_glitch_at_glitch, tref - tglitch)

        twoFsegB = self.run_computefstatistic_single_point(
            tglitch, self.tend, theta_post_glitch[1],
            theta_post_glitch[2], theta_post_glitch[3], Alpha,
            Delta)

        return twoFsegA + twoFsegB


class MCMCSearch(BaseSearchClass):
    """ MCMC search using ComputeFstat"""
489
    @initializer
    def __init__(self, label, outdir, sftfilepath, theta_prior, tref,
                 tstart, tend, nsteps=None, nwalkers=100, ntemps=1,
                 log10temperature_min=-5, theta_initial=None, scatter_val=1e-4,
                 binary=False, BSGL=False, minCoverFreq=None,
                 maxCoverFreq=None, detector=None, earth_ephem=None,
                 sun_ephem=None, theta0_idx=0):
        """
        Parameters
        ----------
        label, outdir: str
            A label and directory to read/write data from/to
        sftfilepath: str
            File pattern to match SFTs
        theta_prior: dict
            Dictionary of priors and fixed values for the search parameters.
            For each parameters (key of the dict), if it is to be held fixed
            the value should be the constant float, if it is be searched, the
            value should be a dictionary of the prior. The dict is copied,
            so the caller's object is not modified.
        theta_initial: dict, array, (None)
            Either a dictionary of distribution about which to distribute the
            initial walkers about, an array (from which the walkers will be
            scattered by scatter_val), or None in which case the prior is used.
        tref, tstart, tend: int
            GPS seconds of the reference time, start time and end time
        nsteps: list (m,), optional
            List specifying the number of steps to take, the last two entries
            give the nburn and nprod of the 'production' run, all entries
            before are for iterative initialisation steps (usually just one)
            e.g. [1000, 1000, 500]. Defaults to [100, 100, 100].
        nwalkers, ntemps: int,
            The number of walkers and temperatures to use in the parallel
            tempered PTSampler.
        log10temperature_min float < 0
            The  log_10(tmin) value, the set of betas passed to PTSampler are
            generated from np.logspace(0, log10temperature_min, ntemps).
        binary: Bool
            If true, search over binary parameters
        BSGL: bool
            If true, compute the BSGL rather than the twoF value.
        detector: str
            Two character reference to the data to use, specify None for no
            constraint.
        minCoverFreq, maxCoverFreq: float
            Minimum and maximum instantaneous frequency which will be covered
            over the SFT time span as passed to CreateFstatInput
        earth_ephem, sun_ephem: str
            Paths of the two files containing positions of Earth and Sun,
            respectively at evenly spaced times, as passed to CreateFstatInput
            If None defaults defined in BaseSearchClass will be used
        theta0_idx: int
            Stored as ``self.theta0_idx``.

        """

        # None (rather than a mutable list default) avoids sharing one list
        # object between every instance
        if self.nsteps is None:
            self.nsteps = [100, 100, 100]

        self.minStartTime = tstart
        self.maxStartTime = tend

        # makedirs so that nested output directories can be created
        if not os.path.isdir(outdir):
            os.makedirs(outdir)
        self.add_log_file()
        logging.info(
            'Set-up MCMC search for model {} on data {}'.format(
                self.label, self.sftfilepath))
        self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
        # Copy before adding tstart/tend so the caller's dict is not mutated
        self.theta_prior = copy.copy(self.theta_prior)
        self.theta_prior['tstart'] = self.tstart
        self.theta_prior['tend'] = self.tend
        self.unpack_input_theta()
        self.ndim = len(self.theta_keys)
        if self.log10temperature_min:
            self.betas = np.logspace(0, self.log10temperature_min, self.ntemps)
        else:
            self.betas = None

        if earth_ephem is None:
            self.earth_ephem = self.earth_ephem_default
        if sun_ephem is None:
            self.sun_ephem = self.sun_ephem_default

        if args.clean and os.path.isfile(self.pickle_path):
            os.rename(self.pickle_path, self.pickle_path+".old")

        self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
        self.log_input()

    def log_input(self):
        """ Record the prior and sampler settings in the log at INFO level """
        settings = [
            'theta_prior = {}'.format(self.theta_prior),
            'nwalkers={}'.format(self.nwalkers),
            'scatter_val = {}'.format(self.scatter_val),
            'nsteps = {}'.format(self.nsteps),
            'ntemps = {}'.format(self.ntemps),
            'log10temperature_min = {}'.format(self.log10temperature_min),
        ]
        for message in settings:
            logging.info(message)

    def inititate_search_object(self):
        """ Create self.search, the ComputeFstat used to evaluate logl

        The search is fully coherent (transient=False) over SFTs whose start
        times lie in [minStartTime, maxStartTime]. ``binary`` is forwarded so
        that the binary parameters in theta are actually used; previously
        they were silently ignored.
        """
        logging.info('Setting up search object')
        self.search = ComputeFstat(
            tref=self.tref, sftfilepath=self.sftfilepath,
            minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq,
            earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem,
            detector=self.detector, BSGL=self.BSGL, transient=False,
            minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
            binary=getattr(self, 'binary', False))

    def logp(self, theta_vals, theta_prior, theta_keys, search):
        """ Log-prior: the sum of each searched parameter's log-prior

        ``search`` is unused here; it is accepted because the same extra
        arguments (logpargs) are passed in by the sampler.
        """
        log_priors = []
        for value, key in zip(theta_vals, theta_keys):
            prior = self.generic_lnprior(**theta_prior[key])
            log_priors.append(prior(value))
        return np.sum(log_priors)

    def logl(self, theta, search):
        """ Log-likelihood: the statistic returned by ``search`` at theta

        The searched values are written in place into self.fixed_theta at
        positions self.theta_idxs. The full parameter list is then passed to
        search.run_computefstatistic_single_point.
        """
        for position, full_idx in enumerate(self.theta_idxs):
            self.fixed_theta[full_idx] = theta[position]
        twoF = search.run_computefstatistic_single_point(*self.fixed_theta)
        return twoF

    def unpack_input_theta(self):
        """ Split self.theta_prior into fixed values and searched parameters

        Sets the following attributes:

        - ``self.theta_keys``: keys whose value is a dict (i.e. searched),
          ordered as in the full parameter list.
        - ``self.fixed_theta``: all parameter values, in the order taken by
          run_computefstatistic_single_point; searched parameters hold a
          placeholder 0.
        - ``self.theta_idxs``: positions of the searched parameters in
          fixed_theta.
        - ``self.theta_symbols``: plot labels of the searched parameters.

        Raises
        ------
        ValueError
            Raised in any of these cases:

            - a value is neither a dict nor a real number;
            - a key is not a recognised parameter;
            - a required parameter is missing.
        """
        full_theta_keys = ['tstart', 'tend', 'F0', 'F1', 'F2', 'Alpha',
                           'Delta']
        # Raw strings: '\d' is an invalid escape sequence in a normal string
        full_theta_symbols = ['_', '_', '$f$', r'$\dot{f}$', r'$\ddot{f}$',
                              r'$\alpha$', r'$\delta$']
        if self.binary:
            full_theta_keys += [
                'asini', 'period', 'ecc', 'tp', 'argp']
            # Exactly one symbol per key; 'period' used to be listed twice,
            # which shifted the labels of ecc, tp and argp
            full_theta_symbols += [
                'asini', 'period', 'ecc', 'tp', 'argp']

        full_theta_keys_copy = copy.copy(full_theta_keys)

        self.theta_keys = []
        fixed_theta_dict = {}
        # .items() works on both Python 2 and 3 (iteritems is Python 2 only)
        for key, val in self.theta_prior.items():
            if isinstance(val, dict):
                fixed_theta_dict[key] = 0
                self.theta_keys.append(key)
            elif (isinstance(val, (int, float, np.integer, np.floating)) and
                  not isinstance(val, bool)):
                fixed_theta_dict[key] = val
            else:
                raise ValueError(
                    'Type {} of {} in theta not recognised'.format(
                        type(val), key))
            if key not in full_theta_keys_copy:
                raise ValueError(
                    'Key {} in theta is not a recognised parameter, expected '
                    'one of {}'.format(key, full_theta_keys))
            full_theta_keys_copy.remove(key)

        if len(full_theta_keys_copy) > 0:
            raise ValueError(('Input dictionary `theta` is missing the '
                              'following keys: {}').format(
                                  full_theta_keys_copy))

        self.fixed_theta = [fixed_theta_dict[key] for key in full_theta_keys]
        self.theta_idxs = [full_theta_keys.index(k) for k in self.theta_keys]
        self.theta_symbols = [full_theta_symbols[i] for i in self.theta_idxs]

        # Order the searched parameters as in the full parameter list
        idxs = np.argsort(self.theta_idxs)
        self.theta_idxs = [self.theta_idxs[i] for i in idxs]
        self.theta_symbols = [self.theta_symbols[i] for i in idxs]
        self.theta_keys = [self.theta_keys[i] for i in idxs]

    def check_initial_points(self, p0):
        """ Warn about, and try to repair, initial walkers outside the prior

        For each temperature, the log-prior of every walker in p0[nt] is
        evaluated. If any are -inf, generate_new_p0_to_fix_initial_points is
        called to redraw them; that method assigns into p0[nt] directly.
        """
        for temp_idx in range(self.ntemps):
            logging.info('Checking temperature {} chains'.format(temp_idx))
            log_priors = np.array(
                [self.logp(walker, self.theta_prior, self.theta_keys,
                           self.search)
                 for walker in p0[temp_idx]])
            n_bad = sum(log_priors == -np.inf)
            if n_bad == 0:
                continue

            logging.warning(
                'Of {} initial values, {} are -np.inf due to the prior'
                .format(len(log_priors), n_bad))
            p0 = self.generate_new_p0_to_fix_initial_points(
                p0, temp_idx, log_priors)

    def generate_new_p0_to_fix_initial_points(self, p0, nt, initial_priors):
        """ Redraw the initial walkers of temperature ``nt`` with -inf prior

        Each bad walker p0[nt][j] is replaced by a copy of a randomly chosen
        walker with a finite log-prior, perturbed by a relative Gaussian
        scatter of 1e-10. The priors are re-evaluated, and this repeats
        until no walker has a -inf prior or 100 rounds have been made. Only
        walkers that are still bad are redrawn in each round.

        Parameters
        ----------
        p0: array-like
            Initial walker positions, indexed as p0[nt][walker]; modified in
            place.
        nt: int
            Temperature index to fix.
        initial_priors: array-like, shape (nwalkers,)
            Log-prior of each walker of p0[nt].

        Returns
        -------
        p0: the (modified) input.
        """
        logging.info('Attempting to correct intial values')
        initial_priors = np.asarray(initial_priors)
        count = 0
        while np.any(initial_priors == -np.inf) and count < 100:
            bad = np.flatnonzero(initial_priors == -np.inf)
            good = np.flatnonzero(np.isfinite(initial_priors))
            if len(good) == 0:
                # Nothing valid to copy from: fall back to any walker
                good = np.arange(self.nwalkers)
            for j in bad:
                p0[nt][j] = (p0[nt][np.random.choice(good)] *
                             (1 + np.random.normal(0, 1e-10, self.ndim)))
            initial_priors = np.array([
                self.logp(p, self.theta_prior, self.theta_keys, self.search)
                for p in p0[nt]])
            count += 1

        if np.any(initial_priors == -np.inf):
            logging.warning('Failed to fix initial priors')
        else:
            logging.info('Suceeded to fix initial priors')

        return p0

    def run_sampler_with_progress_bar(self, sampler, ns, p0):
        """ Advance ``sampler`` by ``ns`` iterations starting from ``p0``

        The sample generator is drained through tqdm (with ``total=ns``) so
        that progress is displayed. The sampler itself is returned.
        """
        iterations = sampler.sample(p0, iterations=ns)
        for _ in tqdm(iterations, total=ns):
            pass
        return sampler

    def run(self, proposal_scale_factor=2):
        """ Run the parallel-tempered MCMC search, or load saved results

        If `self.old_data_is_okay_to_use` is True, the sampler, samples,
        lnprobs and lnlikes are loaded via `get_saved_data()` and nothing is
        run.

        Otherwise an emcee.PTSampler is built and the following happens:
        - It is run once for each entry of self.nsteps[:-2] (initialisation
          stages). New starting points are derived from the sampler after
          each stage, and the sampler is reset.
        - It is then run for nburn + nprod steps. nburn is self.nsteps[-2],
          or 0 when nsteps has a single entry; nprod is self.nsteps[-1].
        - Walker plots are saved to self.outdir.
        - The post-burn-in chain at temperature index 0 is stored on the
          instance and passed to `save_data`.

        Parameters
        ----------
        proposal_scale_factor: float
            Passed to emcee.PTSampler as `a` (default 2)
        """
        if self.old_data_is_okay_to_use is True:
            logging.warning('Using saved data from {}'.format(
                self.pickle_path))
            d = self.get_saved_data()
            self.sampler = d['sampler']
            self.samples = d['samples']
            self.lnprobs = d['lnprobs']
            self.lnlikes = d['lnlikes']
            return

        self.inititate_search_object()

        sampler = emcee.PTSampler(
            self.ntemps, self.nwalkers, self.ndim, self.logl, self.logp,
            logpargs=(self.theta_prior, self.theta_keys, self.search),
            loglargs=(self.search,), betas=self.betas, a=proposal_scale_factor)

        p0 = self.generate_initial_p0()
        p0 = self.apply_corrections_to_p0(p0)
        self.check_initial_points(p0)

        ninit_steps = len(self.nsteps) - 2
        for j, n in enumerate(self.nsteps[:-2]):
            logging.info('Running {}/{} initialisation with {} steps'.format(
                j+1, ninit_steps, n))
            sampler = self.run_sampler_with_progress_bar(sampler, n, p0)
            self._log_acceptance_fractions(sampler)
            self._save_walkers_plot(
                sampler, '{}/{}_init_{}_walkers.png'.format(
                    self.outdir, self.label, j))

            # Derive the next stage's starting points from this stage
            p0 = self.get_new_p0(sampler)
            p0 = self.apply_corrections_to_p0(p0)
            self.check_initial_points(p0)
            sampler.reset()

        if len(self.nsteps) > 1:
            nburn = self.nsteps[-2]
        else:
            nburn = 0
        nprod = self.nsteps[-1]
        logging.info('Running final burn and prod with {} steps'.format(
            nburn+nprod))
        sampler = self.run_sampler_with_progress_bar(sampler, nburn+nprod, p0)
        self._log_acceptance_fractions(sampler)
        self._save_walkers_plot(
            sampler, '{}/{}_walkers.png'.format(self.outdir, self.label),
            burnin_idx=nburn)

        # Keep only temperature index 0, after the first nburn steps
        samples = sampler.chain[0, :, nburn:, :].reshape((-1, self.ndim))
        lnprobs = sampler.lnprobability[0, :, nburn:].reshape((-1))
        lnlikes = sampler.lnlikelihood[0, :, nburn:].reshape((-1))
        self.sampler = sampler
        self.samples = samples
        self.lnprobs = lnprobs
        self.lnlikes = lnlikes
        self.save_data(sampler, samples, lnprobs, lnlikes)

    def _log_acceptance_fractions(self, sampler):
        """ Log the mean acceptance fraction and, when ntemps > 1, the
        temperature-swap acceptance fraction of `sampler` """
        # Mean over axis 1 (presumably walkers, one value per temperature)
        logging.info("Mean acceptance fraction: {}"
                     .format(np.mean(sampler.acceptance_fraction, axis=1)))
        if self.ntemps > 1:
            logging.info("Tswap acceptance fraction: {}"
                         .format(sampler.tswap_acceptance_fraction))

    def _save_walkers_plot(self, sampler, path, burnin_idx=None):
        """ Plot the walkers of `sampler`, save the figure to `path`, then
        close it so repeated stages do not accumulate open figures """
        fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols,
                                      burnin_idx=burnin_idx)
        fig.savefig(path)
        plt.close(fig)

752
    def plot_corner(self, figsize=(7, 7), tglitch_ratio=False,
                    add_prior=False, nstds=None, label_offset=0.4,
                    dpi=300, rc_context=None, **kwargs):
        """ Make a corner plot of self.samples and save it as a png

        Parameters
        ----------
        figsize: tuple
            Figure size passed to plt.subplots.
        tglitch_ratio: bool
            If True, any 'tglitch' parameter is plotted as
            (tglitch - tstart)/(tend - tstart) and labelled R_glitch.
        add_prior: bool
            If True, overlay the prior on the diagonal panels.
        nstds: int, optional
            When an int and no explicit `range` is given, each axis range
            is the median +/- nstds standard deviations of the samples.
        label_offset: float
            Offset used to position the outer axis labels.
        dpi: int
            Resolution of the saved figure.
        rc_context: dict, optional
            matplotlib rc parameters applied while plotting.
        **kwargs:
            Passed to corner.corner; an explicit `range` takes precedence
            over `nstds`.

        The figure is saved to '<outdir>/<label>_corner.png'.
        """
        if rc_context is None:
            rc_context = {}

        with plt.rc_context(rc_context):
            fig, axes = plt.subplots(self.ndim, self.ndim,
                                     figsize=figsize)

            samples_plt = copy.copy(self.samples)
            theta_symbols_plt = [s.replace('_{glitch}', r'_\textrm{glitch}')
                                 for s in self.theta_symbols]

            if tglitch_ratio:
                for j, k in enumerate(self.theta_keys):
                    if k == 'tglitch':
                        s = samples_plt[:, j]
                        samples_plt[:, j] = (s - self.tstart)/(
                                             self.tend - self.tstart)
                        theta_symbols_plt[j] = r'$R_{\textrm{glitch}}$'

            # `range` is removed from kwargs here because it is passed to
            # corner.corner explicitly below; leaving it in kwargs would give
            # a duplicate-keyword TypeError.
            if 'range' in kwargs:
                _range = kwargs.pop('range')
            elif type(nstds) is int:
                _range = []
                for s in samples_plt.T:
                    median = np.median(s)
                    std = np.std(s)
                    _range.append((median - nstds*std, median + nstds*std))
            else:
                _range = None

            fig_triangle = corner.corner(samples_plt,
                                         labels=theta_symbols_plt,
                                         fig=fig,
                                         bins=50,
                                         max_n_ticks=4,
                                         plot_contours=True,
                                         plot_datapoints=True,
                                         label_kwargs={'fontsize': 8},
                                         data_kwargs={'alpha': 0.1,
                                                      'ms': 0.5},
                                         range=_range,
                                         **kwargs)

            axes_list = fig_triangle.get_axes()
            axes = np.array(axes_list).reshape(self.ndim, self.ndim)
            plt.draw()
            for ax in axes[:, 0]:
                ax.yaxis.set_label_coords(-label_offset, 0.5)
            for ax in axes[-1, :]:
                ax.xaxis.set_label_coords(0.5, -label_offset)
            for ax in axes_list:
                ax.set_rasterized(True)
                ax.set_rasterization_zorder(-10)
            plt.tight_layout(h_pad=0.0, w_pad=0.0)
            fig.subplots_adjust(hspace=0.05, wspace=0.05)

            if add_prior:
                self.add_prior_to_corner(axes, samples_plt)

            fig_triangle.savefig('{}/{}_corner.png'.format(
                self.outdir, self.label), dpi=dpi)
813
814
815
816
817
818

    def add_prior_to_corner(self, axes, samples):
        """ Overlay each parameter's log-prior on the corner-plot diagonal

        For parameter number i, the function returned by generic_lnprior is
        evaluated at 100 points spanning the sampled values of column i.
        The curve is drawn in red on a twin y-axis (hidden) of panel
        axes[i][i]. The panel's original x-limits are then restored.
        """
        for idx, key in enumerate(self.theta_keys):
            diag_ax = axes[idx][idx]
            original_xlim = diag_ax.get_xlim()
            lnprior = self.generic_lnprior(**self.theta_prior[key])
            column = samples[:, idx]
            grid = np.linspace(column.min(), column.max(), 100)
            twin = diag_ax.twinx()
            twin.get_yaxis().set_visible(False)
            twin.plot(grid, [lnprior(value) for value in grid], '-r')
            diag_ax.set_xlim(original_xlim)

826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
    def plot_prior_posterior(self, normal_stds=2):
        """ Plot the posterior in the context of the prior

        Each parameter gets one panel. The prior density is drawn in red. A
        Gaussian KDE of the posterior samples is drawn in black on a twin
        y-axis. The figure is saved to '<outdir>/<label>_prior_posterior.png'.

        Parameters
        ----------
        normal_stds: float
            For 'norm' priors the plotted range is loc +/- normal_stds*scale.

        Raises
        ------
        ValueError
            If a parameter's prior type is neither 'unif' nor 'norm'.
        """
        # squeeze=False keeps `axes` iterable even when ndim == 1
        fig, axes = plt.subplots(nrows=self.ndim, figsize=(8, 4*self.ndim),
                                 squeeze=False)
        axes = axes[:, 0]
        N = 1000
        from scipy.stats import gaussian_kde

        for i, (ax, key) in enumerate(zip(axes, self.theta_keys)):
            prior_dict = self.theta_prior[key]
            prior_func = self.generic_lnprior(**prior_dict)
            # generic_lnprior returns the log-density; exponentiate so the
            # prior is shown as a density, like the posterior KDE.
            if prior_dict['type'] == 'unif':
                x = np.linspace(prior_dict['lower'], prior_dict['upper'], N)
                prior = np.exp(prior_func(x))
                # Zero at both edges so the uniform prior is drawn as a box
                prior[0] = 0
                prior[-1] = 0
            elif prior_dict['type'] == 'norm':
                lower = prior_dict['loc'] - normal_stds * prior_dict['scale']
                upper = prior_dict['loc'] + normal_stds * prior_dict['scale']
                x = np.linspace(lower, upper, N)
                prior = np.exp(prior_func(x))
            else:
                raise ValueError('Not implemented for prior type {}'.format(
                    prior_dict['type']))
            priorln = ax.plot(x, prior, 'r', label='prior')
            ax.set_xlabel(self.theta_symbols[i])

            s = self.samples[:, i]
            while len(s) > 10**4:
                # random downsample to avoid slow calculation of kde
                s = np.random.choice(s, size=int(len(s)/2.))
            kde = gaussian_kde(s)
            ax2 = ax.twinx()
            postln = ax2.plot(x, kde.pdf(x), 'k', label='posterior')
            ax2.set_yticklabels([])
            ax.set_yticklabels([])

        lns = priorln + postln
        labs = [l.get_label() for l in lns]
        axes[0].legend(lns, labs, loc=1, framealpha=0.8)

        fig.savefig('{}/{}_prior_posterior.png'.format(
            self.outdir, self.label))

Gregory Ashton's avatar
Gregory Ashton committed
868
    def generic_lnprior(self, **kwargs):
        """ Return a function evaluating the log of the prior density

        Parameters
        ----------
        kwargs: dict
            'type' selects the prior; the other keys are its parameters:
            'unif' (lower, upper), 'halfnorm' / 'neghalfnorm' (loc, scale)
            or 'norm' (loc, scale).

        Returns
        -------
        func: callable
            func(x) gives the log prior density at x, or -np.inf outside
            the support. The 'unif' and 'norm' functions also accept numpy
            arrays; the half-normal ones accept scalars only.

        Raises
        ------
        ValueError
            If kwargs['type'] is not one of the recognised types.
        """

        def logunif(x, a, b):
            # Support is the open interval (a, b)
            below_upper = x < b
            above_lower = x > a
            if type(below_upper) is not np.ndarray:
                if below_upper and above_lower:
                    return -np.log(b-a)
                else:
                    return -np.inf
            else:
                return np.where(below_upper & above_lower,
                                -np.log(b-a), -np.inf)

        def halfnorm(x, loc, scale):
            if x < 0:
                return -np.inf
            else:
                return -0.5*((x-loc)**2/scale**2+np.log(0.5*np.pi*scale**2))

        if kwargs['type'] == 'unif':
            return lambda x: logunif(x, kwargs['lower'], kwargs['upper'])
        elif kwargs['type'] == 'halfnorm':
            return lambda x: halfnorm(x, kwargs['loc'], kwargs['scale'])
        elif kwargs['type'] == 'neghalfnorm':
            return lambda x: halfnorm(-x, kwargs['loc'], kwargs['scale'])
        elif kwargs['type'] == 'norm':
            return lambda x: -0.5*((x - kwargs['loc'])**2/kwargs['scale']**2
                                   + np.log(2*np.pi*kwargs['scale']**2))
        else:
            logging.info("kwargs: {}".format(kwargs))
            raise ValueError("Unrecognised distribution type {}".format(
                kwargs['type']))

Gregory Ashton's avatar
Gregory Ashton committed
920
    def generate_rv(self, **kwargs):
        """ Draw one random value from a named distribution

        Parameters
        ----------
        kwargs: dict
            'type' selects the distribution; the other keys are its
            parameters: 'unif' (lower, upper), while 'norm', 'halfnorm',
            'neghalfnorm' and 'lognorm' use (loc, scale).

        Returns
        -------
        A single value drawn with numpy.random

        Raises
        ------
        ValueError
            If the 'type' is not recognised.
        """
        dist_type = kwargs.pop('type')

        def draw_unif():
            return np.random.uniform(low=kwargs['lower'],
                                     high=kwargs['upper'])

        def draw_norm():
            return np.random.normal(loc=kwargs['loc'], scale=kwargs['scale'])

        def draw_halfnorm():
            return np.abs(draw_norm())

        def draw_neghalfnorm():
            return -1 * np.abs(draw_norm())

        def draw_lognorm():
            return np.random.lognormal(mean=kwargs['loc'],
                                       sigma=kwargs['scale'])

        # Checked in order with ==, first match wins
        dispatch = (("unif", draw_unif),
                    ("norm", draw_norm),
                    ("halfnorm", draw_halfnorm),
                    ("neghalfnorm", draw_neghalfnorm),
                    ("lognorm", draw_lognorm))
        for name, draw in dispatch:
            if dist_type == name:
                return draw()
        raise ValueError("dist_type {} unknown".format(dist_type))

Gregory Ashton's avatar
Gregory Ashton committed
938
    def plot_walkers(self, sampler, symbols=None, alpha=0.4, color="k", temp=0,
                     burnin_idx=None):
        """ Plot all the chains from a sampler

        One panel is drawn per parameter with every walker's trace. Steps
        before `burnin_idx` are drawn in red and the rest in `color`. A final
        panel shows a histogram of the log-likelihood values at temperature
        index `temp`.

        Parameters
        ----------
        sampler: emcee sampler
            Its `chain` may be 3D (nwalkers, nsteps, ndim) or 4D
            (ntemps, nwalkers, nsteps, ndim).
        symbols: list of str, optional
            y-axis labels for the parameter panels.
        alpha: float
            Line transparency of the traces.
        color: str
            Colour of the post-burn-in traces and histogram (default "k").
        temp: int
            Temperature index to plot when the chain is 4D.
        burnin_idx: int, optional
            Number of initial steps to highlight in red as burn-in.

        Returns
        -------
        fig, axes: the figure and the list of its axes (last = histogram)

        Raises
        ------
        ValueError
            If `temp` is not smaller than the number of temperatures.
        """
        shape = sampler.chain.shape
        if len(shape) == 3:
            nwalkers, nsteps, ndim = shape
            chain = sampler.chain[:, :, :]
        if len(shape) == 4:
            ntemps, nwalkers, nsteps, ndim = shape
            if temp < ntemps:
                logging.info("Plotting temperature {} chains".format(temp))
            else:
                raise ValueError(("Requested temperature {} outside of "
                                  "available range").format(temp))
            chain = sampler.chain[temp, :, :, :]

        with plt.style.context(('classic')):
            fig = plt.figure(figsize=(8, 4*ndim))
            ax = fig.add_subplot(ndim+1, 1, 1)
            axes = [ax] + [fig.add_subplot(ndim+1, 1, i, sharex=ax)
                           for i in range(2, ndim+1)]

            idxs = np.arange(chain.shape[1])
            # `axes` is always a list, so one loop handles every ndim,
            # including ndim == 1.
            for i in range(ndim):
                axes[i].ticklabel_format(useOffset=False, axis='y')
                cs = chain[:, :, i].T
                if burnin_idx:
                    axes[i].plot(idxs[:burnin_idx], cs[:burnin_idx],
                                 color="r", alpha=alpha)
                axes[i].plot(idxs[burnin_idx:], cs[burnin_idx:], color=color,
                             alpha=alpha)
                if symbols:
                    axes[i].set_ylabel(symbols[i])

        axes.append(fig.add_subplot(ndim+1, 1, ndim+1))
        # NOTE(review): indexes lnlikelihood by temperature, which assumes a
        # parallel-tempered sampler; the 3D-chain branch may fail here.
        lnl = sampler.lnlikelihood[temp, :, :]
        if burnin_idx:
            axes[-1].hist(lnl[:, :burnin_idx].flatten(), bins=50,
                          histtype='step', color='r')
        axes[-1].hist(lnl[:, burnin_idx:].flatten(), bins=50, histtype='step',
                      color=color)
        if self.BSGL:
            axes[-1].set_xlabel(r'$\mathcal{B}_\mathrm{S/GL}$')
        else:
            axes[-1].set_xlabel(r'$2\mathcal{F}$')

        return fig, axes

Gregory Ashton's avatar
Gregory Ashton committed
992
993
994
995
996
    def apply_corrections_to_p0(self, p0):
        """ Adjust a set of initial walker positions before sampling

        run() passes each new set of starting points through this method.
        This implementation makes no change and returns the same object.
        """
        corrected_p0 = p0
        return corrected_p0

    def generate_scattered_p0(self, p):
997
        """ Generate a set of p0s scattered about p """
Gregory Ashton's avatar
Gregory Ashton committed
998
        p0 = [[p + self.scatter_val * p * np.random.randn(self.ndim)
999
1000
               for i in xrange(self.nwalkers)]
              for j in xrange(self.ntemps)]