pyfstat.py 113 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
""" Classes for various types of searches using ComputeFstatistic """
import os
import sys
import itertools
import logging
import argparse
import copy
import glob
import inspect
from functools import wraps
11
import subprocess
12
from collections import OrderedDict
13
14
15

import numpy as np
import matplotlib
16
matplotlib.use('Agg')
17
import matplotlib.pyplot as plt
18
import scipy.special
19
20
21
import emcee
import corner
import dill as pickle
22
import lal
23
24
import lalpulsar

25
26
27
28
29
30
try:
    from tqdm import tqdm
except ImportError:
    def tqdm(x):
        return x

31
# Global matplotlib defaults for this module: render text with LaTeX and
# show plain (non-offset) tick labels.
plt.rcParams.update({
    'text.usetex': True,
    'axes.formatter.useoffset': False,
})
33

34
35
36
37
38
39
40
def read_config_file(filename):
    """ Parse a simple ``key = value`` configuration file

    Blank lines, lines without an ``=`` and lines starting with ``#`` are
    skipped. The key/value split happens at the first ``=``. All spaces and
    single/double quote characters are removed from both sides.

    Parameters
    ----------
    filename: str
        Path of the file to read.

    Returns
    -------
    d: dict
        Mapping of the (string) keys to (string) values.
    """
    d = {}
    with open(filename, 'r') as f:
        for line in f:
            if line.strip().startswith('#'):
                continue
            k, sep, v = line.partition('=')
            if not sep:
                # blank / malformed line: nothing to store
                continue
            k = k.replace(' ', '').strip()
            v = (v.replace(' ', '').replace("'", "").replace('"', '')
                 .replace('\n', '').strip())
            d[k] = v
    return d


config_file = os.path.expanduser('~')+'/.pyfstat.conf'
if os.path.isfile(config_file):
    d = read_config_file(config_file)
    earth_ephem = d.get('earth_ephem')
    sun_ephem = d.get('sun_ephem')
    if earth_ephem is None or sun_ephem is None:
        logging.warning('{} does not define both earth_ephem and sun_ephem, '
                        'please provide the paths when initialising '
                        'searches'.format(config_file))
else:
    logging.warning('No ~/.pyfstat.conf file found please provide the paths '
                    'when initialising searches')
    earth_ephem = None
    sun_ephem = None

51
52
53
54
55
parser = argparse.ArgumentParser()
# '--quite' is the historical (misspelled) flag; '--quiet' is accepted as an
# alias. Both store into args.quite so existing code keeps working.
parser.add_argument("-q", "--quite", "--quiet", dest="quite",
                    help="Decrease output verbosity", action="store_true")
parser.add_argument("-c", "--clean", help="Don't use cached data",
                    action="store_true")
parser.add_argument("-u", "--use-old-data", action="store_true")
parser.add_argument('-s', "--setup-only", action="store_true")
parser.add_argument('-n', "--no-template-counting", action="store_true")
parser.add_argument('unittest_args', nargs='*')
# Known flags are consumed here. sys.argv is then reduced to the positional
# leftovers, presumably for a later unittest runner (per the dest name).
args, unknown = parser.parse_known_args()
sys.argv[1:] = args.unittest_args

# Root logger passes everything (DEBUG+) to its handlers. The console
# handler's own threshold is raised to WARNING by -q/--quite/--quiet.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
stream_handler = logging.StreamHandler()
if args.quite:
    stream_handler.setLevel(logging.WARNING)
else:
    stream_handler.setLevel(logging.DEBUG)
stream_handler.setFormatter(logging.Formatter(
    '%(asctime)s %(levelname)-8s: %(message)s', datefmt='%H:%M'))
logger.addHandler(stream_handler)
73

74

75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
def round_to_n(x, n):
    """ Round x to n significant figures; any falsy x gives 0 """
    if x:
        magnitude = int(np.floor(np.log10(abs(x))))
        shift = (n - 1) - magnitude
        scale = (10 ** shift)
        return round(x * scale) / scale
    return 0


def texify_float(x, d=1):
    r""" Format a number as a short string for use in (LaTeX) labels

    Parameters
    ----------
    x: float
        The number to format.
    d: int
        Number of significant figures kept (via round_to_n).

    Returns
    -------
    str
        '0' for zero. ``str(x)`` of the rounded value when
        0.01 < |x| < 100. Otherwise a LaTeX string
        ``${stem}{\times}10^{power}$`` (stem is an int when d == 1).
    """
    if not x:
        # log10(0) is -inf and int(-inf) raises OverflowError, so zero must
        # not reach the scientific-notation branch below.
        return '0'
    x = round_to_n(x, d)
    if 0.01 < abs(x) < 100:
        return str(x)
    power = int(np.floor(np.log10(abs(x))))
    stem = np.round(x / 10**power, d)
    if d == 1:
        stem = int(stem)
    return r'${}{{\times}}10^{{{}}}$'.format(stem, power)


95
def initializer(func):
    """ Decorator function to automatically assign the parameters to self

    Wraps an ``__init__``-style method. Before the wrapped body runs:
    - every positional and keyword argument passed is stored as an
      attribute of the same name on ``self``;
    - every defaulted parameter whose attribute does not yet exist is set
      to its default value.
    """
    try:
        spec = inspect.getfullargspec(func)
    except AttributeError:
        # Python 2 only provides getargspec (removed in Python 3.11)
        spec = inspect.getargspec(func)
    names = spec.args
    # spec.defaults is None when func declares no default values
    defaults = spec.defaults or ()

    @wraps(func)
    def wrapper(self, *args, **kargs):
        for name, arg in list(zip(names[1:], args)) + list(kargs.items()):
            setattr(self, name, arg)

        # Defaults belong to the *last* len(defaults) parameter names
        for name, default in zip(reversed(names), reversed(defaults)):
            if not hasattr(self, name):
                setattr(self, name, default)

        func(self, *args, **kargs)

    return wrapper


def read_par(label, outdir):
    """ Read in a .par file, returns a dictionary of the values

    Reads ``{outdir}/{label}.par``. Lines without an ``=`` are ignored.
    Every other line is split at its first ``=``:
    - the stripped left-hand side is the key;
    - the right-hand side, with trailing ``;``/spaces removed, is converted
      to a ``np.float64``.
    """
    filename = '{}/{}.par'.format(outdir, label)
    d = {}
    with open(filename, 'r') as f:
        for line in f:
            key, sep, val = line.rstrip('\n').partition('=')
            if not sep:
                continue
            key = key.strip()
            val = val.strip().rstrip('; ')
            try:
                value = float(val)
            except ValueError:
                # NOTE(review): eval() runs arbitrary code from the file;
                # only read .par files from trusted sources.
                value = eval(val)
            d[key] = np.float64(value)
    return d


class BaseSearchClass(object):
    """ The base search class, provides general functions """

    earth_ephem_default = earth_ephem
    sun_ephem_default = sun_ephem

    def add_log_file(self):
        """ Log output to a file, requires class to have outdir and label

        Adds a `logging.FileHandler` for ``{outdir}/{label}.log`` (level
        INFO) to the root logger.
        """
        logfilename = '{}/{}.log'.format(self.outdir, self.label)
        fh = logging.FileHandler(logfilename)
        fh.setLevel(logging.INFO)
        fh.setFormatter(logging.Formatter(
            '%(asctime)s %(levelname)-8s: %(message)s',
            datefmt='%y-%m-%d %H:%M'))
        logging.getLogger().addHandler(fh)

    def shift_matrix(self, n, dT):
        """ Generate the shift matrix

        The matrix is upper triangular:
        - ones on the diagonal;
        - ``dT**(j-i) / (j-i)!`` above it.
        The first row carries an extra factor of 2*pi. It maps the frequency
        terms onto the phase coefficient, presumably in radians.

        Parameters
        ----------
        n: int
            The dimension of the shift-matrix to generate
        dT: float
            The time delta of the shift matrix

        Returns
        -------
        m: array (n, n)
            The shift matrix
        """
        # math.factorial: the `np.math` alias was removed in NumPy 2.0
        from math import factorial
        dT = float(dT)
        m = np.eye(n)
        for i in range(n):
            for j in range(i + 1, n):
                if i == 0:
                    m[i, j] = 2*np.pi*dT**(j-i) / factorial(j-i)
                else:
                    m[i, j] = dT**(j-i) / factorial(j-i)
        return m

    def shift_coefficients(self, theta, dT):
        """ Shift a set of coefficients by dT

        Parameters
        ----------
        theta: array-like, shape (n,)
            vector of the expansion coefficients to transform starting from the
            lowest degree e.g [phi, F0, F1,...].
        dT: float
            difference between the two reference times as tref_new - tref_old.

        Returns
        -------
        theta_new: array-like shape (n,)
            vector of the coefficients as evaluated at the new reference time.
        """
        n = len(theta)
        m = self.shift_matrix(n, dT)
        return np.dot(m, theta)

    def calculate_thetas(self, theta, delta_thetas, tbounds, theta0_idx=0):
        """ Calculates the set of coefficients for the post-glitch signal

        Parameters
        ----------
        theta: array-like, shape (n,)
            Coefficients (at self.tref) of segment `theta0_idx`.
        delta_thetas: sequence of array-like, shape (n,)
            delta_thetas[i] is the jump in the coefficients at glitch i,
            which occurs at tbounds[i+1] (between segments i and i+1).
        tbounds: array-like
            Segment boundaries.
        theta0_idx: int
            Index (zero-based) of the segment that `theta` refers to.

        Returns
        -------
        thetas: list
            Coefficients (at self.tref) of each segment, in segment order.
        """
        def across_glitch(theta_ref, tglitch, jump, forward):
            # Evaluate theta_ref at the glitch, apply (or undo) the jump,
            # then shift the result back to tref.
            at_glitch = self.shift_coefficients(theta_ref, tglitch - self.tref)
            if forward:
                at_glitch = at_glitch + jump
            else:
                at_glitch = at_glitch - jump
            return self.shift_coefficients(at_glitch, self.tref - tglitch)

        thetas = [theta]
        nbefore = min(theta0_idx, len(delta_thetas))
        # Glitches before segment theta0_idx are undone latest-first, so each
        # jump is removed from the segment directly after that glitch.
        for i in reversed(range(nbefore)):
            thetas.insert(0, across_glitch(
                thetas[0], tbounds[i+1], delta_thetas[i], forward=False))
        for i in range(nbefore, len(delta_thetas)):
            thetas.append(across_glitch(
                thetas[-1], tbounds[i+1], delta_thetas[i], forward=True))
        return thetas

    def generate_loudest(self):
        """ Run lalapps_ComputeFstatistic_v2 at the parameters in the .par

        Steps:
        - reads ``{outdir}/{label}.par``, taking Alpha, Delta, F0 and F1
          from self.theta_prior when they are absent from the file;
        - asks lalapps_ComputeFstatistic_v2 to write
          ``{outdir}/{label}.loudest``.
        A non-zero exit status, or a missing executable, is logged rather
        than raised.
        """
        params = read_par(self.label, self.outdir)
        for key in ['Alpha', 'Delta', 'F0', 'F1']:
            if key not in params:
                params[key] = self.theta_prior[key]
        # Argument list (no shell): paths containing spaces, quotes or shell
        # metacharacters are passed through verbatim.
        cmd = ['lalapps_ComputeFstatistic_v2',
               '-a', '{}'.format(params['Alpha']),
               '-d', '{}'.format(params['Delta']),
               '-f', '{}'.format(params['F0']),
               '-s', '{}'.format(params['F1']),
               '-D', '{}'.format(self.sftfilepath),
               '--refTime={}'.format(params['tref']),
               '--outputLoudest={}/{}.loudest'.format(self.outdir, self.label),
               '--minStartTime={}'.format(self.minStartTime),
               '--maxStartTime={}'.format(self.maxStartTime)]
        try:
            returncode = subprocess.call(cmd)
        except OSError as err:
            logging.error('Unable to run {}: {}'.format(cmd[0], err))
            return
        if returncode != 0:
            logging.warning('{} exited with status {}'.format(
                cmd[0], returncode))

227

class ComputeFstat(object):
    """ Base class providing interface to `lalpulsar.ComputeFstat` """

    earth_ephem_default = earth_ephem
    sun_ephem_default = sun_ephem

    @initializer
    def __init__(self, tref, sftfilepath=None, minStartTime=None,
                 maxStartTime=None, binary=False, transient=True, BSGL=False,
                 detector=None, minCoverFreq=None, maxCoverFreq=None,
                 earth_ephem=None, sun_ephem=None, injectSources=None
                 ):
        """
        Parameters
        ----------
        tref: int
            GPS seconds of the reference time.
        sftfilepath: str
            File pattern to match SFTs
        minStartTime, maxStartTime: float GPStime
            Only use SFTs with timestamps starting from (including, excluding)
            this epoch
        binary: bool
            If true, search of binary parameters.
        transient: bool
            If true, allow for the Fstat to be computed over a transient range.
        BSGL: bool
            If true, compute the BSGL rather than the twoF value.
        detector: str
            Two character reference to the data to use, specify None for no
            constraint.
        minCoverFreq, maxCoverFreq: float
            The min and max cover frequency passed to CreateFstatInput, if
            either is None the range of frequencies in the SFT less 1Hz is
            used.
        earth_ephem, sun_ephem: str
            Paths of the two files containing positions of Earth and Sun,
            respectively at evenly spaced times, as passed to CreateFstatInput.
            If None, the class defaults (earth_ephem_default,
            sun_ephem_default) are used.
        injectSources: dict or None
            If a dict, a signal is injected through
            FstatOptionalArgs.injectSources.
            - Keys read: 'h0', 'cosi', 'phi0', 'psi', 'Alpha', 'Delta' and
              'fkdot' (at most 7 terms, starting with F0).
            - Without a 't0' key the injection is non-transient.

        """

        if earth_ephem is None:
            self.earth_ephem = self.earth_ephem_default
        if sun_ephem is None:
            self.sun_ephem = self.sun_ephem_default

        self.init_computefstatistic_single_point()

    @staticmethod
    def _tconvert(gps_seconds):
        """ Date string for a GPS time from `lalapps_tconvert`

        Only used for logging, so failures give 'unknown' instead of raising.
        """
        try:
            return subprocess.check_output(
                ['lalapps_tconvert', '{}'.format(int(gps_seconds))],
                universal_newlines=True).rstrip('\n')
        except (OSError, subprocess.CalledProcessError):
            return 'unknown'

    def init_computefstatistic_single_point(self):
        """ Initialisation step of run_computefstatistic for a single point

        Loads the SFTs matching self.sftfilepath, subject to the detector and
        start-time constraints. Then creates:
        - self.FstatInput, self.PulsarDopplerParams and self.FstatResults;
        - if self.BSGL: self.BSGLSetup and self.twoFX;
        - if self.transient: self.windowRange.
        It also sets self.names and self.whatToCompute, and fills in
        minCoverFreq/maxCoverFreq when either is None.

        Raises
        ------
        ValueError
            If no data is loaded, if BSGL is requested with single-detector
            data, or if the Fstar0 search reaches the end of its grid.
        """
        logging.info('Initialising SFTCatalog')
        constraints = lalpulsar.SFTConstraints()
        if self.detector:
            constraints.detector = self.detector
        if self.minStartTime:
            constraints.minStartTime = lal.LIGOTimeGPS(self.minStartTime)
        if self.maxStartTime:
            constraints.maxStartTime = lal.LIGOTimeGPS(self.maxStartTime)

        logging.info('Loading data matching pattern {}'.format(
                     self.sftfilepath))
        SFTCatalog = lalpulsar.SFTdataFind(self.sftfilepath, constraints)
        names = list(set([d.header.name for d in SFTCatalog.data]))
        self.names = names
        SFT_timestamps = [d.header.epoch for d in SFTCatalog.data]
        if len(names) == 0:
            raise ValueError('No data loaded.')

        # Optional terminal histogram of the data; purely informational.
        try:
            from bashplotlib.histogram import plot_hist
        except ImportError:
            plot_hist = None
        if plot_hist is not None:
            try:
                print('Data timestamps histogram:')
                plot_hist(SFT_timestamps, height=5, bincount=50)
            except IOError:
                pass

        logging.info('Loaded {} data files from detectors {}'.format(
            len(SFT_timestamps), names))
        logging.info('Data spans from {} ({}) to {} ({})'.format(
            int(SFT_timestamps[0]), self._tconvert(SFT_timestamps[0]),
            int(SFT_timestamps[-1]), self._tconvert(SFT_timestamps[-1])))

        logging.info('Initialising ephems')
        ephems = lalpulsar.InitBarycenter(self.earth_ephem, self.sun_ephem)

        logging.info('Initialising FstatInput')
        dFreq = 0
        if self.transient:
            self.whatToCompute = lalpulsar.FSTATQ_ATOMS_PER_DET
        else:
            self.whatToCompute = lalpulsar.FSTATQ_2F

        FstatOAs = lalpulsar.FstatOptionalArgs()
        OAdefaults = lalpulsar.FstatOptionalArgsDefaults
        for field in ('randSeed', 'SSBprec', 'Dterms', 'runningMedianWindow',
                      'FstatMethod', 'injectSqrtSX', 'assumeSqrtSX',
                      'prevInput', 'collectTiming'):
            setattr(FstatOAs, field, getattr(OAdefaults, field))

        injectSources = getattr(self, 'injectSources', None)
        if isinstance(injectSources, dict):
            logging.info('Injecting source with params: {}'.format(
                injectSources))
            PPV = lalpulsar.CreatePulsarParamsVector(1)
            PP = PPV.data[0]
            PP.Amp.h0 = injectSources['h0']
            PP.Amp.cosi = injectSources['cosi']
            PP.Amp.phi0 = injectSources['phi0']
            PP.Amp.psi = injectSources['psi']
            PP.Doppler.Alpha = injectSources['Alpha']
            PP.Doppler.Delta = injectSources['Delta']
            # Pad to the same 7-element fkdot used for PulsarDopplerParams
            fkdot = np.zeros(7)
            fkdot[:len(injectSources['fkdot'])] = injectSources['fkdot']
            PP.Doppler.fkdot = fkdot
            PP.Doppler.refTime = self.tref
            if 't0' not in injectSources:
                PP.Transient.type = lalpulsar.TRANSIENT_NONE
            # NOTE(review): when 't0' is given, no transient window (t0, tau)
            # is configured for the injection - TODO.
            FstatOAs.injectSources = PPV
        else:
            FstatOAs.injectSources = OAdefaults.injectSources

        if self.minCoverFreq is None or self.maxCoverFreq is None:
            fAs = [d.header.f0 for d in SFTCatalog.data]
            fBs = [d.header.f0 + (d.numBins-1)*d.header.deltaF
                   for d in SFTCatalog.data]
            self.minCoverFreq = np.min(fAs) + 0.5
            self.maxCoverFreq = np.max(fBs) - 0.5
            logging.info('Min/max cover freqs not provided, using '
                         '{} and {}, est. from SFTs'.format(
                             self.minCoverFreq, self.maxCoverFreq))

        self.FstatInput = lalpulsar.CreateFstatInput(SFTCatalog,
                                                     self.minCoverFreq,
                                                     self.maxCoverFreq,
                                                     dFreq,
                                                     ephems,
                                                     FstatOAs
                                                     )

        logging.info('Initialising PulsarDoplerParams')
        PulsarDopplerParams = lalpulsar.PulsarDopplerParams()
        PulsarDopplerParams.refTime = self.tref
        PulsarDopplerParams.Alpha = 1
        PulsarDopplerParams.Delta = 1
        PulsarDopplerParams.fkdot = np.array([0, 0, 0, 0, 0, 0, 0])
        self.PulsarDopplerParams = PulsarDopplerParams

        logging.info('Initialising FstatResults')
        self.FstatResults = lalpulsar.FstatResults()

        if self.BSGL:
            if len(names) < 2:
                raise ValueError("Can't use BSGL with single detector data")
            logging.info('Initialising BSGL')

            # Tuning parameters - to be reviewed
            numDetectors = len(names)
            nsegs = getattr(self, 'nsegs', None)
            if nsegs:
                # Fstar0 where the 2*nsegs-dof tail probability is ~1e-6
                p_val_threshold = 1e-6
                Fstar0s = np.linspace(0, 1000, 10000)
                p_vals = scipy.special.gammaincc(2*nsegs, Fstar0s)
                Fstar0 = Fstar0s[np.argmin(np.abs(p_vals - p_val_threshold))]
                if Fstar0 == Fstar0s[-1]:
                    raise ValueError('Max Fstar0 exceeded')
            else:
                Fstar0 = 15.
            logging.info('Using Fstar0 of {:1.2f}'.format(Fstar0))
            # Length-10 buffers; assumed to match lalpulsar's maximum number
            # of detectors - TODO confirm. Equal prior weight per detector.
            oLGX = np.zeros(10)
            oLGX[:numDetectors] = 1./numDetectors
            self.BSGLSetup = lalpulsar.CreateBSGLSetup(numDetectors,
                                                       Fstar0,
                                                       oLGX,
                                                       True,
                                                       1)
            self.twoFX = np.zeros(10)
            self.whatToCompute = (self.whatToCompute |
                                  lalpulsar.FSTATQ_2F_PER_DET)

        if self.transient:
            logging.info('Initialising transient parameters')
            self.windowRange = lalpulsar.transientWindowRange_t()
            self.windowRange.type = lalpulsar.TRANSIENT_RECTANGULAR
            self.windowRange.t0Band = 0
            self.windowRange.dt0 = 1
            self.windowRange.tauBand = 0
            self.windowRange.dtau = 1

    def _set_transient_twoFX(self):
        """ Fill self.twoFX with per-detector 2F over self.windowRange

        ComputeTransientFstatMap takes a MultiFstatAtomVector, so each
        detector's atoms are wrapped in a single-detector
        MultiFstatAtomVector. A shallow copy of FstatResults would not work:
        its multiFatoms still contain every detector.
        """
        multiFatoms = self.FstatResults.multiFatoms[0]
        for X in range(self.FstatResults.numDetectors):
            singleIFO = lalpulsar.CreateMultiFstatAtomVector(1)
            singleIFO.data[0] = lalpulsar.CreateFstatAtomVector(
                multiFatoms.data[X].length)
            singleIFO.data[0].TAtom = multiFatoms.data[X].TAtom
            singleIFO.data[0].data = multiFatoms.data[X].data
            FX = lalpulsar.ComputeTransientFstatMap(
                singleIFO, self.windowRange, False)
            self.twoFX[X] = 2*FX.F_mn.data[0][0]

    def compute_fullycoherent_det_stat_single_point(
            self, F0, F1, F2, Alpha, Delta, asini=None, period=None, ecc=None,
            tp=None, argp=None):
        """ Compute the fully-coherent det. statistic at a single point

        Uses the window [self.minStartTime, self.maxStartTime].
        """

        return self.run_computefstatistic_single_point(
            self.minStartTime, self.maxStartTime, F0, F1, F2, Alpha, Delta,
            asini, period, ecc, tp, argp)

    def run_computefstatistic_single_point(self, tstart, tend, F0, F1,
                                           F2, Alpha, Delta, asini=None,
                                           period=None, ecc=None, tp=None,
                                           argp=None):
        """ Returns twoF or ln(BSGL) fully-coherently at a single point

        tstart and tend define the rectangular window used in transient mode.
        They are ignored when self.transient is false. The binary parameters
        are only used when self.binary is true.
        """

        self.PulsarDopplerParams.fkdot = np.array([F0, F1, F2, 0, 0, 0, 0])
        self.PulsarDopplerParams.Alpha = Alpha
        self.PulsarDopplerParams.Delta = Delta
        if self.binary:
            self.PulsarDopplerParams.asini = asini
            self.PulsarDopplerParams.period = period
            self.PulsarDopplerParams.ecc = ecc
            self.PulsarDopplerParams.tp = tp
            self.PulsarDopplerParams.argp = argp

        lalpulsar.ComputeFstat(self.FstatResults,
                               self.FstatInput,
                               self.PulsarDopplerParams,
                               1,
                               self.whatToCompute
                               )

        if not self.transient:
            if not self.BSGL:
                return self.FstatResults.twoF[0]

            twoF = float(self.FstatResults.twoF[0])
            for X in range(self.FstatResults.numDetectors):
                self.twoFX[X] = self.FstatResults.twoFPerDet(X)
            log10_BSGL = lalpulsar.ComputeBSGL(twoF, self.twoFX,
                                               self.BSGLSetup)
            # log10 -> natural log
            return log10_BSGL/np.log10(np.exp(1))

        self.windowRange.t0 = int(tstart)  # TYPE UINT4
        self.windowRange.tau = int(tend - tstart)  # TYPE UINT4

        FS = lalpulsar.ComputeTransientFstatMap(
            self.FstatResults.multiFatoms[0], self.windowRange, False)
        twoF = 2*FS.F_mn.data[0][0]

        if not self.BSGL:
            return twoF

        self._set_transient_twoFX()
        log10_BSGL = lalpulsar.ComputeBSGL(twoF, self.twoFX, self.BSGLSetup)
        # log10 -> natural log
        return log10_BSGL/np.log10(np.exp(1))

    def calculate_twoF_cumulative(self, F0, F1, F2, Alpha, Delta, asini=None,
                                  period=None, ecc=None, tp=None, argp=None,
                                  tstart=None, tend=None, npoints=1000,
                                  minfraction=0.01, maxfraction=1):
        """ Calculate the cumulative twoF along the observation span

        tstart/tend default to self.minStartTime/self.maxStartTime. Windows
        start at tstart + minfraction*(tend - tstart), and their durations
        run from minfraction to maxfraction of (tend - tstart). Transient
        mode is switched on (and the search re-initialised) if it was off.

        Returns
        -------
        taus: array
            Window durations in seconds.
        twoFs: array
            twoF (or ln(BSGL) when self.BSGL) for each window.
        """
        if tstart is None:
            tstart = self.minStartTime
        if tend is None:
            tend = self.maxStartTime
        duration = tend - tstart
        tstart = tstart + minfraction*duration
        taus = np.linspace(minfraction*duration, maxfraction*duration, npoints)
        twoFs = []
        if not self.transient:
            self.transient = True
            self.init_computefstatistic_single_point()
        for tau in taus:
            twoFs.append(self.run_computefstatistic_single_point(
                tstart=tstart, tend=tstart+tau, F0=F0, F1=F1, F2=F2,
                Alpha=Alpha, Delta=Delta, asini=asini, period=period, ecc=ecc,
                tp=tp, argp=argp))

        return taus, np.array(twoFs)

    def plot_twoF_cumulative(self, label, outdir, ax=None, c='k', savefig=True,
                             title=None, **kwargs):
        """ Plot the output of calculate_twoF_cumulative

        kwargs are passed to calculate_twoF_cumulative. Return value:
        - savefig true: the figure is written to
          ``{outdir}/{label}_twoFcumulative.png`` and (taus, twoFs) is
          returned;
        - savefig false: the axes are returned.
        """
        taus, twoFs = self.calculate_twoF_cumulative(**kwargs)
        if ax is None:
            fig, ax = plt.subplots()
        ax.plot(taus/86400., twoFs, label=label, color=c)
        ax.set_xlabel(r'Days from $t_{{\rm start}}={:.0f}$'.format(
            kwargs.get('tstart', self.minStartTime)))
        if self.BSGL:
            # run_computefstatistic_single_point returns ln(BSGL), not log10
            ax.set_ylabel(r'$\ln(\mathrm{BSGL})_{\rm cumulative}$')
        else:
            ax.set_ylabel(r'$\widetilde{2\mathcal{F}}_{\rm cumulative}$')
        ax.set_xlim(0, taus[-1]/86400)
        ax.set_title(title)
        if savefig:
            plt.savefig('{}/{}_twoFcumulative.png'.format(outdir, label))
            return taus, twoFs
        else:
            return ax

class SemiCoherentSearch(BaseSearchClass, ComputeFstat):
    """ A semi-coherent search """

    @initializer
    def __init__(self, label, outdir, tref, nsegs=None, sftfilepath=None,
                 binary=False, BSGL=False, minStartTime=None,
                 maxStartTime=None, minCoverFreq=None, maxCoverFreq=None,
                 detector=None, earth_ephem=None, sun_ephem=None,
                 injectSources=None):
        """
        Parameters
        ----------
        label, outdir: str
            A label and directory to read/write data from/to.
        tref: int
            GPS seconds of the reference time.
        nsegs: int
            The (fixed) number of segments; required.
        sftfilepath: str
            File pattern to match SFTs
        minStartTime, maxStartTime: int
            GPS seconds of the start and end of the data; this span is
            divided into nsegs equal-length segments.

        For all other parameters, see pyfstat.ComputeFstat.

        Raises
        ------
        ValueError
            If nsegs is not given or is smaller than 1.
        """

        if self.nsegs is None or self.nsegs < 1:
            raise ValueError('nsegs must be a positive integer, got {}'.format(
                self.nsegs))
        self.fs_file_name = "{}/{}_FS.dat".format(self.outdir, self.label)
        if self.earth_ephem is None:
            self.earth_ephem = self.earth_ephem_default
        if self.sun_ephem is None:
            self.sun_ephem = self.sun_ephem_default
        self.transient = True
        self.init_computefstatistic_single_point()
        self.init_semicoherent_parameters()

    def init_semicoherent_parameters(self):
        """ Set the segment boundaries self.tboundaries and whatToCompute """
        logging.info(('Initialising semicoherent parameters from {} to {} in'
                      ' {} segments').format(
            self.minStartTime, self.maxStartTime, self.nsegs))
        self.transient = True
        self.whatToCompute = lalpulsar.FSTATQ_2F+lalpulsar.FSTATQ_ATOMS_PER_DET
        self.tboundaries = np.linspace(self.minStartTime, self.maxStartTime,
                                       self.nsegs+1)

    def _set_transient_twoFX(self):
        """ Fill self.twoFX with per-detector 2F over self.windowRange

        ComputeTransientFstatMap takes a MultiFstatAtomVector, so each
        detector's atoms are wrapped in a single-detector
        MultiFstatAtomVector before computing its map.
        """
        multiFatoms = self.FstatResults.multiFatoms[0]
        for X in range(self.FstatResults.numDetectors):
            singleIFO = lalpulsar.CreateMultiFstatAtomVector(1)
            singleIFO.data[0] = lalpulsar.CreateFstatAtomVector(
                multiFatoms.data[X].length)
            singleIFO.data[0].TAtom = multiFatoms.data[X].TAtom
            singleIFO.data[0].data = multiFatoms.data[X].data
            FX = lalpulsar.ComputeTransientFstatMap(
                singleIFO, self.windowRange, False)
            self.twoFX[X] = 2*FX.F_mn.data[0][0]

    def run_semi_coherent_computefstatistic_single_point(
            self, F0, F1, F2, Alpha, Delta, asini=None,
            period=None, ecc=None, tp=None, argp=None):
        """ Returns twoF or ln(BSGL) semi-coherently at a single point

        The per-segment statistic (2F, or ln(BSGL) when self.BSGL) is
        computed over each [tboundaries[k], tboundaries[k+1]] window and
        summed over the segments.
        """

        self.PulsarDopplerParams.fkdot = np.array([F0, F1, F2, 0, 0, 0, 0])
        self.PulsarDopplerParams.Alpha = Alpha
        self.PulsarDopplerParams.Delta = Delta
        if self.binary:
            self.PulsarDopplerParams.asini = asini
            self.PulsarDopplerParams.period = period
            self.PulsarDopplerParams.ecc = ecc
            self.PulsarDopplerParams.tp = tp
            self.PulsarDopplerParams.argp = argp

        lalpulsar.ComputeFstat(self.FstatResults,
                               self.FstatInput,
                               self.PulsarDopplerParams,
                               1,
                               self.whatToCompute
                               )

        if not self.transient:
            if not self.BSGL:
                return self.FstatResults.twoF[0]

            twoF = float(self.FstatResults.twoF[0])
            for X in range(self.FstatResults.numDetectors):
                self.twoFX[X] = self.FstatResults.twoFPerDet(X)
            log10_BSGL = lalpulsar.ComputeBSGL(twoF, self.twoFX,
                                               self.BSGLSetup)
            # log10 -> natural log
            return log10_BSGL/np.log10(np.exp(1))

        detStat = 0
        for tstart, tend in zip(self.tboundaries[:-1], self.tboundaries[1:]):
            self.windowRange.t0 = int(tstart)  # TYPE UINT4
            self.windowRange.tau = int(tend - tstart)  # TYPE UINT4

            FS = lalpulsar.ComputeTransientFstatMap(
                self.FstatResults.multiFatoms[0], self.windowRange, False)
            twoF = 2*FS.F_mn.data[0][0]

            if not self.BSGL:
                detStat += twoF
                continue

            self._set_transient_twoFX()
            log10_BSGL = lalpulsar.ComputeBSGL(
                    twoF, self.twoFX, self.BSGLSetup)
            # log10 -> natural log
            detStat += log10_BSGL/np.log10(np.exp(1))

        return detStat
638
class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat):
    """ A semi-coherent glitch search

    This implements a basic semi-coherent glitch F-stat in which the data
    is divided into segments either side of the proposed glitches and the
    fully-coherent F-stat in each segment is summed to give the semi-coherent
    F-stat
    """

    @initializer
    def __init__(self, label, outdir, tref, tstart, tend, nglitch=0,
                 sftfilepath=None, theta0_idx=0, BSGL=False, minStartTime=None,
                 maxStartTime=None, minCoverFreq=None, maxCoverFreq=None,
                 detector=None, earth_ephem=None, sun_ephem=None):
        """
        Parameters
        ----------
        label, outdir: str
            A label and directory to read/write data from/to.
        tref, tstart, tend: int
            GPS seconds of the reference time, and start and end of the data.
        nglitch: int
            The (fixed) number of glitches; this can be zero, but occasionally
            this causes issues (in which case just use ComputeFstat).
        sftfilepath: str
            File pattern to match SFTs
        theta0_idx, int
            Index (zero-based) of which segment the theta refers to - useful
            if providing a tight prior on theta to allow the signal to jump
            to theta (and not just from)

        For all other parameters, see pyfstat.ComputeFStat.
        """

        self.fs_file_name = "{}/{}_FS.dat".format(self.outdir, self.label)
        if self.earth_ephem is None:
            self.earth_ephem = self.earth_ephem_default
        if self.sun_ephem is None:
            self.sun_ephem = self.sun_ephem_default
        # Each segment is evaluated over its own [ts, te) window, and this
        # search does not model binary orbits.
        self.transient = True
        self.binary = False
        self.init_computefstatistic_single_point()

    def compute_nglitch_fstat(self, F0, F1, F2, Alpha, Delta, *args):
        """ Returns the semi-coherent glitch summed twoF

        The trailing positional arguments are read from the end as three
        consecutive runs of length `nglitch`: the delta_F0s, the delta_F1s
        and the glitch times. Returns -np.inf if the summed value is not
        finite.
        """
        args = list(args)
        n = self.nglitch
        nargs = len(args)
        # Slice from an explicit length rather than with negative indices:
        # for n == 0, args[-0:] would be the *whole* list, not an empty one.
        tglitches = args[nargs - n:]
        delta_F1s = args[nargs - 2*n:nargs - n]
        delta_F0s = args[nargs - 3*n:nargs - 2*n]
        tboundaries = [self.tstart] + tglitches + [self.tend]

        delta_F2 = np.zeros(len(delta_F0s))
        delta_phi = np.zeros(len(delta_F0s))
        theta = [0, F0, F1, F2]
        # One row per glitch: (delta_phi, delta_F0, delta_F1, delta_F2)
        delta_thetas = np.atleast_2d(
                np.array([delta_phi, delta_F0s, delta_F1s, delta_F2]).T)

        thetas = self.calculate_thetas(theta, delta_thetas, tboundaries,
                                       theta0_idx=self.theta0_idx)

        twoFSum = 0
        for i, theta_i_at_tref in enumerate(thetas):
            ts, te = tboundaries[i], tboundaries[i+1]
            twoFSum += self.run_computefstatistic_single_point(
                ts, te, theta_i_at_tref[1], theta_i_at_tref[2],
                theta_i_at_tref[3], Alpha, Delta)

        if np.isfinite(twoFSum):
            return twoFSum
        return -np.inf

    def compute_glitch_fstat_single(self, F0, F1, F2, Alpha, Delta, delta_F0,
                                    delta_F1, tglitch):
        """ Returns the semi-coherent glitch summed twoF for nglitch=1

        Segment A spans [tstart, tglitch) using (F0, F1, F2) as given at tref.
        Segment B spans [tglitch, tend): the spin parameters are shifted to
        tglitch, the (delta_F0, delta_F1, 0) jump is added, and the result is
        shifted back to tref. If tglitch == tend only segment A is computed.

        Note: used for testing
        """

        theta = [F0, F1, F2]
        delta_theta = [delta_F0, delta_F1, 0]
        tref = self.tref

        theta_at_glitch = self.shift_coefficients(theta, tglitch - tref)
        # Coerce to arrays so '+' is element-wise even if a plain list is
        # returned (list + list would concatenate).
        theta_post_glitch_at_glitch = (np.asarray(theta_at_glitch)
                                       + np.asarray(delta_theta))
        theta_post_glitch = self.shift_coefficients(
            theta_post_glitch_at_glitch, tref - tglitch)

        twoFsegA = self.run_computefstatistic_single_point(
            self.tstart, tglitch, theta[0], theta[1], theta[2], Alpha,
            Delta)

        if tglitch == self.tend:
            return twoFsegA

        twoFsegB = self.run_computefstatistic_single_point(
            tglitch, self.tend, theta_post_glitch[0],
            theta_post_glitch[1], theta_post_glitch[2], Alpha,
            Delta)

        return twoFsegA + twoFsegB


Gregory Ashton's avatar
Gregory Ashton committed
742
743
class MCMCSearch(BaseSearchClass):
    """ MCMC search using ComputeFstat"""
744
    @initializer
    def __init__(self, label, outdir, sftfilepath, theta_prior, tref,
                 tstart, tend, nsteps=[100, 100, 100], nwalkers=100, ntemps=1,
                 log10temperature_min=-5, theta_initial=None, scatter_val=1e-10,
                 binary=False, BSGL=False, minCoverFreq=None,
                 maxCoverFreq=None, detector=None, earth_ephem=None,
                 sun_ephem=None, injectSources=None):
        """
        Parameters
        ----------
        label, outdir: str
            A label and directory to read/write data from/to
        sftfilepath: str
            File pattern to match SFTs
        theta_prior: dict
            Dictionary of priors and fixed values for the search parameters.
            For each parameter (key of the dict), if it is to be held fixed
            the value should be the constant float; if it is to be searched,
            the value should be a dictionary of the prior.
        theta_initial: dict, array, (None)
            Either a dictionary of distributions about which to distribute the
            initial walkers, an array (from which the walkers will be
            scattered by scatter_val), or None in which case the prior is used.
        tref, tstart, tend: int
            GPS seconds of the reference time, start time and end time
        nsteps: list (m,)
            List specifying the number of steps to take, the last two entries
            give the nburn and nprod of the 'production' run, all entries
            before are for iterative initialisation steps (usually just one)
            e.g. [1000, 1000, 500]. (The default list object is shared between
            instances; it is only read, never mutated.)
        nwalkers, ntemps: int,
            The number of walkers and temperatures to use in the parallel
            tempered PTSampler.
        log10temperature_min: float < 0
            The log_10(tmin) value, the set of betas passed to PTSampler are
            generated from np.logspace(0, log10temperature_min, ntemps).
            If falsy, betas=None is passed instead.
        binary: Bool
            If true, search over binary parameters
        BSGL: bool
            Passed to ComputeFstat; selects the BSGL detection statistic
            instead of twoF.
        detector: str
            Two character reference to the data to use, specify None for no
            constraint.
        minCoverFreq, maxCoverFreq: float
            Minimum and maximum instantaneous frequency which will be covered
            over the SFT time span as passed to CreateFstatInput
        earth_ephem, sun_ephem: str
            Paths of the two files containing positions of Earth and Sun,
            respectively at evenly spaced times, as passed to CreateFstatInput
            If None defaults defined in BaseSearchClass will be used
        injectSources:
            Passed unchanged to ComputeFstat.

        """

        self.minStartTime = tstart
        self.maxStartTime = tend

        if not os.path.isdir(outdir):
            # makedirs (unlike mkdir) also creates missing parent directories
            os.makedirs(outdir)
        self.add_log_file()

        logging.info(
            'Set-up MCMC search for model {} on data {}'.format(
                self.label, self.sftfilepath))
        self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
        self.unpack_input_theta()
        self.ndim = len(self.theta_keys)
        if self.log10temperature_min:
            self.betas = np.logspace(0, self.log10temperature_min, self.ntemps)
        else:
            self.betas = None

        if earth_ephem is None:
            self.earth_ephem = self.earth_ephem_default
        if sun_ephem is None:
            self.sun_ephem = self.sun_ephem_default

        # --clean on the command line: keep the old pickle aside so it is
        # not reused by run()
        if args.clean and os.path.isfile(self.pickle_path):
            os.rename(self.pickle_path, self.pickle_path+".old")

        self.log_input()

    def log_input(self):
        """ Write the main sampler configuration to the log at INFO level """
        logging.info('theta_prior = {}'.format(self.theta_prior))
        logging.info('nwalkers={}'.format(self.nwalkers))
        logging.info('scatter_val = {}'.format(self.scatter_val))
        logging.info('nsteps = {}'.format(self.nsteps))
        logging.info('ntemps = {}'.format(self.ntemps))
        logging.info('log10temperature_min = {}'.format(
            self.log10temperature_min))
829
830
831

    def inititate_search_object(self):
        """ Create self.search, the (non-transient) ComputeFstat instance used
        to evaluate the log-likelihood, configured from this search's data
        selection, ephemerides and injection settings. """
        logging.info('Setting up search object')
        self.search = ComputeFstat(
            tref=self.tref, sftfilepath=self.sftfilepath,
            minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq,
            earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem,
            detector=self.detector, BSGL=self.BSGL, transient=False,
            minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
            binary=self.binary, injectSources=self.injectSources)
839
840

    def logp(self, theta_vals, theta_prior, theta_keys, search):
Gregory Ashton's avatar
Gregory Ashton committed
841
        H = [self.generic_lnprior(**theta_prior[key])(p) for p, key in
842
843
844
845
846
847
             zip(theta_vals, theta_keys)]
        return np.sum(H)

    def logl(self, theta, search):
        for j, theta_i in enumerate(self.theta_idxs):
            self.fixed_theta[theta_i] = theta[j]
848
849
        FS = search.compute_fullycoherent_det_stat_single_point(
            *self.fixed_theta)
850
851
852
        return FS

    def unpack_input_theta(self):
853
        full_theta_keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta']
854
855
856
        if self.binary:
            full_theta_keys += [
                'asini', 'period', 'ecc', 'tp', 'argp']
857
858
        full_theta_keys_copy = copy.copy(full_theta_keys)

859
860
        full_theta_symbols = ['$f$', '$\dot{f}$', '$\ddot{f}$', r'$\alpha$',
                              r'$\delta$']
861
862
863
864
        if self.binary:
            full_theta_symbols += [
                'asini', 'period', 'period', 'ecc', 'tp', 'argp']

865
866
        self.theta_keys = []
        fixed_theta_dict = {}
867
        for key, val in self.theta_prior.iteritems():
868
869
            if type(val) is dict:
                fixed_theta_dict[key] = 0
Gregory Ashton's avatar
Gregory Ashton committed
870
                self.theta_keys.append(key)
871
872
873
874
875
876
            elif type(val) in [float, int, np.float64]:
                fixed_theta_dict[key] = val
            else:
                raise ValueError(
                    'Type {} of {} in theta not recognised'.format(
                        type(val), key))
Gregory Ashton's avatar
Gregory Ashton committed
877
            full_theta_keys_copy.pop(full_theta_keys_copy.index(key))
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893

        if len(full_theta_keys_copy) > 0:
            raise ValueError(('Input dictionary `theta` is missing the'
                              'following keys: {}').format(
                                  full_theta_keys_copy))

        self.fixed_theta = [fixed_theta_dict[key] for key in full_theta_keys]
        self.theta_idxs = [full_theta_keys.index(k) for k in self.theta_keys]
        self.theta_symbols = [full_theta_symbols[i] for i in self.theta_idxs]

        idxs = np.argsort(self.theta_idxs)
        self.theta_idxs = [self.theta_idxs[i] for i in idxs]
        self.theta_symbols = [self.theta_symbols[i] for i in idxs]
        self.theta_keys = [self.theta_keys[i] for i in idxs]

    def check_initial_points(self, p0):
        """ Check initial walker positions against the prior, per temperature.

        For each temperature index nt, walkers in p0[nt] whose log-prior is
        -inf are reported and regenerated by
        generate_new_p0_to_fix_initial_points.

        Returns
        -------
        p0: the (possibly corrected) initial positions. The repair writes into
            p0[nt] in place, so callers that ignore the return value still
            see the corrections when p0 is a mutable array.
        """
        for nt in range(self.ntemps):
            logging.info('Checking temperature {} chains'.format(nt))
            initial_priors = np.array([
                self.logp(p, self.theta_prior, self.theta_keys, self.search)
                for p in p0[nt]])
            number_of_initial_out_of_bounds = sum(initial_priors == -np.inf)

            if number_of_initial_out_of_bounds > 0:
                logging.warning(
                    'Of {} initial values, {} are -np.inf due to the prior'
                    .format(len(initial_priors),
                            number_of_initial_out_of_bounds))

                p0 = self.generate_new_p0_to_fix_initial_points(
                    p0, nt, initial_priors)
        return p0

    def generate_new_p0_to_fix_initial_points(self, p0, nt, initial_priors):
        logging.info('Attempting to correct intial values')
        idxs = np.arange(self.nwalkers)[initial_priors == -np.inf]
        count = 0
        while sum(initial_priors == -np.inf) > 0 and count < 100:
            for j in idxs:
                p0[nt][j] = (p0[nt][np.random.randint(0, self.nwalkers)]*(
                             1+np.random.normal(0, 1e-10, self.ndim)))
            initial_priors = np.array([
                self.logp(p, self.theta_prior, self.theta_keys,
                          self.search)
                for p in p0[nt]])
            count += 1

        if sum(initial_priors == -np.inf) > 0:
            logging.info('Failed to fix initial priors')
        else:
            logging.info('Suceeded to fix initial priors')

        return p0
930

Gregory Ashton's avatar
Gregory Ashton committed
931
    def run_sampler_with_progress_bar(self, sampler, ns, p0):
        """ Advance `sampler` by `ns` iterations starting from `p0`, showing a
        progress bar when tqdm is available; returns the same sampler. """
        iterator = sampler.sample(p0, iterations=ns)
        try:
            iterator = tqdm(iterator, total=ns)
        except TypeError:
            # The import-time fallback `tqdm(x)` (used when tqdm is not
            # installed) takes no `total` keyword: iterate without a bar.
            pass
        for _ in iterator:
            pass
        return sampler

936
    def run(self, proposal_scale_factor=2, **kwargs):
        """ Run the MCMC search.

        If check_old_data_is_okay_to_use() returns True, the sampler, samples,
        lnprobs and lnlikes are loaded from the saved data and nothing is
        run. Otherwise an emcee.PTSampler is run for each initialisation
        stage nsteps[:-2] (re-seeding the walkers after each stage), then for
        nburn + nprod steps. The temperature-0 chain after burn-in is stored
        as self.samples / self.lnprobs / self.lnlikes and saved.

        Parameters
        ----------
        proposal_scale_factor: float
            Passed to emcee.PTSampler as the proposal scale `a`.
        **kwargs:
            Passed through to plot_walkers.
        """
        self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
        if self.old_data_is_okay_to_use is True:
            logging.warning('Using saved data from {}'.format(
                self.pickle_path))
            d = self.get_saved_data()
            self.sampler = d['sampler']
            self.samples = d['samples']
            self.lnprobs = d['lnprobs']
            self.lnlikes = d['lnlikes']
            return

        self.inititate_search_object()

        sampler = emcee.PTSampler(
            self.ntemps, self.nwalkers, self.ndim, self.logl, self.logp,
            logpargs=(self.theta_prior, self.theta_keys, self.search),
            loglargs=(self.search,), betas=self.betas, a=proposal_scale_factor)

        p0 = self.generate_initial_p0()
        p0 = self.apply_corrections_to_p0(p0)
        self.check_initial_points(p0)

        ninit_steps = len(self.nsteps) - 2
        for j, n in enumerate(self.nsteps[:-2]):
            logging.info('Running {}/{} initialisation with {} steps'.format(
                j+1, ninit_steps, n))
            sampler = self.run_sampler_with_progress_bar(sampler, n, p0)
            logging.info("Mean acceptance fraction: {}"
                         .format(np.mean(sampler.acceptance_fraction, axis=1)))
            if self.ntemps > 1:
                logging.info("Tswap acceptance fraction: {}"
                             .format(sampler.tswap_acceptance_fraction))
            fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols,
                                          **kwargs)
            fig.savefig('{}/{}_init_{}_walkers.png'.format(
                self.outdir, self.label, j), dpi=200)
            # pyplot keeps every figure alive until it is closed
            plt.close(fig)

            # Re-seed the walkers from this stage before resetting the chain
            p0 = self.get_new_p0(sampler)
            p0 = self.apply_corrections_to_p0(p0)
            self.check_initial_points(p0)
            sampler.reset()

        nburn = self.nsteps[-2] if len(self.nsteps) > 1 else 0
        nprod = self.nsteps[-1]
        logging.info('Running final burn and prod with {} steps'.format(
            nburn+nprod))
        sampler = self.run_sampler_with_progress_bar(sampler, nburn+nprod, p0)
        logging.info("Mean acceptance fraction: {}"
                     .format(np.mean(sampler.acceptance_fraction, axis=1)))
        if self.ntemps > 1:
            logging.info("Tswap acceptance fraction: {}"
                         .format(sampler.tswap_acceptance_fraction))

        fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols,
                                      burnin_idx=nburn, **kwargs)
        fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label),
                    dpi=200)
        plt.close(fig)

        # Keep only the temperature-0 chain, with burn-in discarded
        samples = sampler.chain[0, :, nburn:, :].reshape((-1, self.ndim))
        lnprobs = sampler.lnprobability[0, :, nburn:].reshape((-1))
        lnlikes = sampler.lnlikelihood[0, :, nburn:].reshape((-1))
        self.sampler = sampler
        self.samples = samples
        self.lnprobs = lnprobs
        self.lnlikes = lnlikes
        self.save_data(sampler, samples, lnprobs, lnlikes)

1008
    def plot_corner(self, figsize=(7, 7), tglitch_ratio=False,
                    add_prior=False, nstds=None, label_offset=0.4,
                    dpi=300, rc_context=None, **kwargs):
        """ Save a corner plot of self.samples to {outdir}/{label}_corner.png

        With a single search parameter a histogram is drawn instead.

        Parameters
        ----------
        figsize: tuple
        tglitch_ratio: bool
            If True, any 'tglitch' column is rescaled to
            (tglitch - tstart)/(tend - tstart) and labelled R_glitch.
        add_prior: bool
            If True, overlay the priors on the diagonal (add_prior_to_corner).
        nstds: int or None
            If an int (and no 'range' kwarg), each axis range is
            median +/- nstds * std of that column.
        label_offset: float
            Axis-label offset applied to the outer axes.
        dpi: int
        rc_context: dict or None
            Passed to plt.rc_context; None means no overrides.
        **kwargs:
            Passed to corner.corner.
        """
        if rc_context is None:
            rc_context = {}

        if self.ndim < 2:
            with plt.rc_context(rc_context):
                fig, ax = plt.subplots(figsize=figsize)
                ax.hist(self.samples, bins=50, histtype='stepfilled')
                ax.set_xlabel(self.theta_symbols[0])

            fig.savefig('{}/{}_corner.png'.format(
                self.outdir, self.label), dpi=dpi)
            plt.close(fig)
            return

        with plt.rc_context(rc_context):
            fig, axes = plt.subplots(self.ndim, self.ndim,
                                     figsize=figsize)

            samples_plt = copy.copy(self.samples)
            theta_symbols_plt = copy.copy(self.theta_symbols)
            theta_symbols_plt = [s.replace('_{glitch}', r'_\textrm{glitch}')
                                 for s in theta_symbols_plt]

            if tglitch_ratio:
                for j, k in enumerate(self.theta_keys):
                    if k == 'tglitch':
                        s = samples_plt[:, j]
                        samples_plt[:, j] = (s - self.tstart)/(
                                             self.tend - self.tstart)
                        theta_symbols_plt[j] = r'$R_{\textrm{glitch}}$'

            if type(nstds) is int and 'range' not in kwargs:
                _range = []
                for s in samples_plt.T:
                    median = np.median(s)
                    std = np.std(s)
                    _range.append((median - nstds*std, median + nstds*std))
            else:
                _range = None

            fig_triangle = corner.corner(samples_plt,
                                         labels=theta_symbols_plt,
                                         fig=fig,
                                         bins=50,
                                         max_n_ticks=4,
                                         plot_contours=True,
                                         plot_datapoints=True,
                                         label_kwargs={'fontsize': 8},
                                         data_kwargs={'alpha': 0.1,
                                                      'ms': 0.5},
                                         range=_range,
                                         **kwargs)

            axes_list = fig_triangle.get_axes()
            axes = np.array(axes_list).reshape(self.ndim, self.ndim)
            plt.draw()
            for ax in axes[:, 0]:
                ax.yaxis.set_label_coords(-label_offset, 0.5)
            for ax in axes[-1, :]:
                ax.xaxis.set_label_coords(0.5, -label_offset)
            for ax in axes_list:
                ax.set_rasterized(True)
                ax.set_rasterization_zorder(-10)
            plt.tight_layout(h_pad=0.0, w_pad=0.0)
            fig.subplots_adjust(hspace=0.05, wspace=0.05)

            if add_prior:
                self.add_prior_to_corner(axes, samples_plt)

            fig_triangle.savefig('{}/{}_corner.png'.format(
                self.outdir, self.label), dpi=dpi)
            plt.close(fig_triangle)
1079
1080
1081
1082
1083
1084

    def add_prior_to_corner(self, axes, samples):
        """ Overlay each searched parameter's prior on a corner-plot diagonal.

        Parameters
        ----------
        axes: square grid of matplotlib axes, indexed axes[i][i]
        samples: 2D array whose column i matches self.theta_keys[i]

        The prior is evaluated at 100 points spanning the samples' range and
        drawn in red on a twin axis with a hidden y-axis; the diagonal axis'
        x-limits are restored afterwards.
        """
        for i, key in enumerate(self.theta_keys):
            ax = axes[i][i]
            xlim = ax.get_xlim()
            s = samples[:, i]
            prior = self.generic_lnprior(**self.theta_prior[key])
            x = np.linspace(s.min(), s.max(), 100)
            ax2 = ax.twinx()
            ax2.get_yaxis().set_visible(False)
            # NOTE(review): generic_lnprior returns the *log* prior, so this
            # curve is ln p(x) rather than p(x) -- confirm that is intended.
            ax2.plot(x, [prior(xi) for xi in x], '-r')
            ax.set_xlim(xlim)

1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
    def plot_prior_posterior(self, normal_stds=2):
        """ Plot the posterior in the context of the prior

        One panel per searched parameter. The prior function from
        generic_lnprior is drawn in red over a grid (the full support for
        'unif', loc +/- normal_stds*scale for the normal-type priors), and a
        Gaussian KDE of the posterior samples is drawn in black on a twin
        axis. Saved to {outdir}/{label}_prior_posterior.png.

        Raises
        ------
        ValueError
            For prior types other than unif/norm/halfnorm/neghalfnorm.
        """
        from scipy.stats import gaussian_kde

        # squeeze=False keeps a 2D array even when ndim == 1, where
        # plt.subplots would otherwise return a single (non-iterable) Axes.
        fig, axes = plt.subplots(nrows=self.ndim, figsize=(8, 4*self.ndim),
                                 squeeze=False)
        axes = axes[:, 0]
        N = 1000

        for i, (ax, key) in enumerate(zip(axes, self.theta_keys)):
            prior_dict = self.theta_prior[key]
            prior_func = self.generic_lnprior(**prior_dict)
            # NOTE(review): prior_func returns the log prior, plotted here
            # alongside a (linear) KDE density -- confirm that is intended.
            if prior_dict['type'] == 'unif':
                x = np.linspace(prior_dict['lower'], prior_dict['upper'], N)
                prior = prior_func(x)
                # The end points are outside the open interval (-inf there)
                prior[0] = 0
                prior[-1] = 0
            elif prior_dict['type'] == 'norm':
                lower = prior_dict['loc'] - normal_stds * prior_dict['scale']
                upper = prior_dict['loc'] + normal_stds * prior_dict['scale']
                x = np.linspace(lower, upper, N)
                prior = prior_func(x)
            elif prior_dict['type'] == 'halfnorm':
                lower = prior_dict['loc']
                upper = prior_dict['loc'] + normal_stds * prior_dict['scale']
                x = np.linspace(lower, upper, N)
                # halfnorm's log-prior only accepts scalars
                prior = [prior_func(xi) for xi in x]
            elif prior_dict['type'] == 'neghalfnorm':
                upper = prior_dict['loc']
                lower = prior_dict['loc'] - normal_stds * prior_dict['scale']
                x = np.linspace(lower, upper, N)
                prior = [prior_func(xi) for xi in x]
            else:
                raise ValueError('Not implemented for prior type {}'.format(
                    prior_dict['type']))
            priorln = ax.plot(x, prior, 'r', label='prior')
            ax.set_xlabel(self.theta_symbols[i])

            s = self.samples[:, i]
            while len(s) > 10**4:
                # random downsample to avoid slow calculation of kde
                s = np.random.choice(s, size=int(len(s)/2.))
            kde = gaussian_kde(s)
            ax2 = ax.twinx()
            postln = ax2.plot(x, kde.pdf(x), 'k', label='posterior')
            ax2.set_yticklabels([])
            ax.set_yticklabels([])

        lns = priorln + postln
        labs = [l.get_label() for l in lns]
        axes[0].legend(lns, labs, loc=1, framealpha=0.8)

        fig.savefig('{}/{}_prior_posterior.png'.format(
            self.outdir, self.label))
        plt.close(fig)

1144
    def plot_cumulative_max(self, **kwargs):
        """ Plot the cumulative twoF at the maximum-twoF point.

        Parameters returned by get_max_twoF() are completed with the fixed
        values from theta_prior, then passed to
        self.search.plot_twoF_cumulative over [tstart, tend]. The search
        object is created first if it does not exist yet. **kwargs are
        passed through.
        """
        d, maxtwoF = self.get_max_twoF()
        for key, val in self.theta_prior.items():
            if key not in d:
                d[key] = val

        if not hasattr(self, 'search'):
            self.inititate_search_object()
        if self.binary is False:
            self.search.plot_twoF_cumulative(
                self.label, self.outdir, F0=d['F0'], F1=d['F1'], F2=d['F2'],
                Alpha=d['Alpha'], Delta=d['Delta'], tstart=self.tstart,
                tend=self.tend, **kwargs)
        else:
            self.search.plot_twoF_cumulative(
                self.label, self.outdir, F0=d['F0'], F1=d['F1'], F2=d['F2'],
                Alpha=d['Alpha'], Delta=d['Delta'], asini=d['asini'],
                period=d['period'], ecc=d['ecc'], argp=d['argp'], tp=d['tp'],
                tstart=self.tstart, tend=self.tend, **kwargs)
Gregory Ashton's avatar
Gregory Ashton committed
1163

Gregory Ashton's avatar
Gregory Ashton committed
1164
    def generic_lnprior(self, **kwargs):
        """ Return a function of x giving the log of the pdf

        Parameters
        ----------
        kwargs: dict
            A dictionary containing 'type' of pdf and shape parameters:
            'unif' (lower, upper), 'norm', 'halfnorm', 'neghalfnorm',
            'lognorm' (loc, scale). For 'lognorm', loc and scale are the mean
            and standard deviation of log(x), matching generate_rv.

        Returns
        -------
        callable: x -> log-pdf. 'unif' and 'norm' also accept numpy arrays;
            the others expect scalars.

        Raises
        ------
        ValueError
            If 'type' is not recognised.
        """

        def logunif(x, a, b):
            inside = (x > a) & (x < b)
            if not isinstance(inside, np.ndarray):
                return -np.log(b-a) if inside else -np.inf
            return np.where(inside, -np.log(b-a), -np.inf)

        def halfnorm(x, loc, scale):
            if x < 0:
                return -np.inf
            return -0.5*((x-loc)**2/scale**2+np.log(0.5*np.pi*scale**2))

        def lognorm(x, loc, scale):
            if x <= 0:
                return -np.inf
            return (-np.log(x*scale*np.sqrt(2*np.pi))
                    - (np.log(x) - loc)**2/(2*scale**2))

        if kwargs['type'] == 'unif':
            return lambda x: logunif(x, kwargs['lower'], kwargs['upper'])
        elif kwargs['type'] == 'halfnorm':
            return lambda x: halfnorm(x, kwargs['loc'], kwargs['scale'])
        elif kwargs['type'] == 'neghalfnorm':
            return lambda x: halfnorm(-x, kwargs['loc'], kwargs['scale'])
        elif kwargs['type'] == 'norm':
            return lambda x: -0.5*((x - kwargs['loc'])**2/kwargs['scale']**2
                                   + np.log(2*np.pi*kwargs['scale']**2))
        elif kwargs['type'] == 'lognorm':
            return lambda x: lognorm(x, kwargs['loc'], kwargs['scale'])
        else:
            logging.info("kwargs: %s", kwargs)
            raise ValueError("Unrecognised distribution type: {}".format(
                kwargs['type']))

Gregory Ashton's avatar
Gregory Ashton committed
1216
    def generate_rv(self, **kwargs):
        """ Draw a single random value from a distribution specification.

        Parameters
        ----------
        kwargs: dict
            Must contain 'type', one of 'unif' (uses 'lower', 'upper'),
            'norm', 'halfnorm', 'neghalfnorm' or 'lognorm' (use 'loc',
            'scale'). 'halfnorm'/'neghalfnorm' return |N(loc, scale)| and
            -|N(loc, scale)|; 'lognorm' treats loc/scale as the mean/sigma of
            the underlying normal.

        Raises
        ------
        ValueError
            If 'type' is not one of the above.
        """
        dist_type = kwargs.pop('type')
        # Lambdas defer the kwargs lookups, so only the parameters of the
        # selected distribution need to be present.
        samplers = {
            "unif": lambda: np.random.uniform(low=kwargs['lower'],
                                              high=kwargs['upper']),
            "norm": lambda: np.random.normal(loc=kwargs['loc'],
                                             scale=kwargs['scale']),
            "halfnorm": lambda: np.abs(np.random.normal(
                loc=kwargs['loc'], scale=kwargs['scale'])),
            "neghalfnorm": lambda: -1 * np.abs(np.random.normal(
                loc=kwargs['loc'], scale=kwargs['scale'])),
            "lognorm": lambda: np.random.lognormal(
                mean=kwargs['loc'], sigma=kwargs['scale']),
        }
        if dist_type not in samplers:
            raise ValueError("dist_type {} unknown".format(dist_type))
        return samplers[dist_type]()

Gregory Ashton's avatar
Gregory Ashton committed
1234
    def plot_walkers(self, sampler, symbols=None, alpha=0.4, color="k", temp=0,
Gregory Ashton's avatar
Gregory Ashton committed
1235
                     lw=0.1, burnin_idx=None, add_det_stat_burnin=False,
1236
1237
                     fig=None, axes=None, xoffset=0, plot_det_stat=True,
                     context='classic'):
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
        """ Plot all the chains from a sampler """

        shape = sampler.chain.shape
        if len(shape) == 3:
            nwalkers, nsteps, ndim = shape
            chain = sampler.chain[:, :, :]
        if len(shape) == 4:
            ntemps, nwalkers, nsteps, ndim = shape
            if temp < ntemps:
                logging.info("Plotting temperature {} chains".format(temp))
            else:
                raise ValueError(("Requested temperature {} outside of"
                                  "available range").format(temp))
            chain = sampler.chain[temp, :, :, :]

1253
        with plt.style.context((context)):
Gregory Ashton's avatar
Gregory Ashton committed
1254
1255
1256
1257
1258
            if fig is None and axes is None:
                fig = plt.figure(figsize=(8, 4*ndim))
                ax = fig.add_subplot(ndim+1, 1, 1)
                axes = [ax] + [fig.add_subplot(ndim+1, 1, i, sharex=ax)
                               for i in range(2, ndim+1)]
1259

Gregory Ashton's avatar
Gregory Ashton committed
1260
            idxs = np.arange(chain.shape[1])
1261
1262
            if ndim > 1:
                for i in range(ndim):
1263
                    axes[i].ticklabel_format(useOffset=False, axis='y')
1264
1265
                    if i < ndim:
                        axes[i].set_xticklabels([])
Gregory Ashton's avatar
Gregory Ashton committed
1266
1267
                    cs = chain[:, :, i].T
                    if burnin_idx:
Gregory Ashton's avatar
Gregory Ashton committed
1268
1269
1270
1271
1272
                        axes[i].plot(xoffset+idxs[:burnin_idx],
                                     cs[:burnin_idx], color="r", alpha=alpha,
                                     lw=lw)
                    axes[i].plot(xoffset+idxs[burnin_idx:], cs[burnin_idx:],
                                 color="k", alpha=alpha, lw=lw)
1273
1274
                    if symbols:
                        axes[i].set_ylabel(symbols[i])
1275
            else:
Gregory Ashton's avatar
Gregory Ashton committed
1276
                axes[0].ticklabel_format(useOffset=False, axis='y')
Gregory Ashton's avatar
Gregory Ashton committed
1277
                cs = chain[:, :, temp].T
Gregory Ashton's avatar
Gregory Ashton committed
1278
1279
1280
1281
1282
1283
1284
                if burnin_idx:
                    axes[0].plot(idxs[:burnin_idx], cs[:burnin_idx],
                                 color="r", alpha=alpha, lw=lw)
                axes[0].plot(idxs[burnin_idx:], cs[burnin_idx:], color="k",
                             alpha=alpha, lw=lw)
                if symbols:
                    axes[0].set_ylabel(symbols[0])
1285

1286
1287
            if len(axes) == ndim:
                axes.append(fig.add_subplot(ndim+1, 1, ndim+1))
Gregory Ashton's avatar
Gregory Ashton committed
1288

1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
            if plot_det_stat:
                lnl = sampler.lnlikelihood[temp, :, :]
                if burnin_idx and add_det_stat_burnin:
                    burn_in_vals = lnl[:, :burnin_idx].flatten