mcmc_based_searches.py 71 KB
Newer Older
Gregory Ashton's avatar
Gregory Ashton committed
1
2
""" Searches using MCMC-based methods """

3
import sys
Gregory Ashton's avatar
Gregory Ashton committed
4
import os
5
import copy
Gregory Ashton's avatar
Gregory Ashton committed
6
import logging
7
from collections import OrderedDict
8
9
10
11
12
13
14
15

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import emcee
import corner
import dill as pickle

Gregory Ashton's avatar
Gregory Ashton committed
16
17
18
19
from core import BaseSearchClass, ComputeFstat
from core import tqdm, args, earth_ephem, sun_ephem
from optimal_setup_functions import get_optimal_setup
import helper_functions
20
21


Gregory Ashton's avatar
Gregory Ashton committed
22
23
class MCMCSearch(BaseSearchClass):
    """ MCMC search using ComputeFstat"""
Gregory Ashton's avatar
Gregory Ashton committed
24
    @helper_functions.initializer
    def __init__(self, label, outdir, sftfilepath, theta_prior, tref,
                 minStartTime, maxStartTime, nsteps=[100, 100],
                 nwalkers=100, ntemps=1, log10temperature_min=-5,
                 theta_initial=None, scatter_val=1e-10,
                 binary=False, BSGL=False, minCoverFreq=None,
                 maxCoverFreq=None, detector=None, earth_ephem=None,
                 sun_ephem=None, injectSources=None):
        """
        Parameters
        label, outdir: str
            A label and directory to read/write data from/to. The directory
            (including any missing parent directories) is created if it does
            not exist.
        sftfilepath: str
            File pattern to match SFTs
        theta_prior: dict
            Dictionary of priors and fixed values for the search parameters.
            For each parameter (key of the dict), if it is to be held fixed
            the value should be the constant float, if it is to be searched,
            the value should be a dictionary of the prior.
        theta_initial: dict, list, array, (None)
            Either a dictionary of distributions about which to distribute the
            initial walkers, a list of such distribution dictionaries (one per
            searched parameter), an array (from which the walkers will be
            scattered by scatter_val), or None in which case the prior is used.
        scatter_val: float
            Relative scatter used when theta_initial is an array: each
            coordinate p is drawn as p + scatter_val*p*N(0, 1).
        tref, minStartTime, maxStartTime: int
            GPS seconds of the reference time, start time and end time
        nsteps: list (m,)
            List specifying the number of steps to take, the last two entries
            give the nburn and nprod of the 'production' run, all entries
            before are for iterative initialisation steps (usually just one)
            e.g. [1000, 1000, 500].
        nwalkers, ntemps: int,
            The number of walkers and temperatures to use in the parallel
            tempered PTSampler.
        log10temperature_min: float < 0
            The log_10(tmin) value, the set of betas passed to PTSampler are
            generated from np.logspace(0, log10temperature_min, ntemps).
            If falsy (e.g. None or 0), betas=None is passed instead.
        binary: bool
            If true, search over binary parameters
        BSGL: bool
            Passed through to ComputeFstat (presumably selects the BSGL
            statistic instead of 2F).
        detector: str
            Two character reference to the data to use, specify None for no
            constraint.
        minCoverFreq, maxCoverFreq: float
            Minimum and maximum instantaneous frequency which will be covered
            over the SFT time span as passed to CreateFstatInput
        earth_ephem, sun_ephem: str
            Paths of the two files containing positions of Earth and Sun,
            respectively at evenly spaced times, as passed to CreateFstatInput
            If None defaults defined in BaseSearchClass will be used
        injectSources:
            Passed through unchanged to ComputeFstat.

        """
        # NOTE(review): helper_functions.initializer presumably stores every
        # argument as an attribute (self.label, self.outdir, ...); the code
        # below relies on that.

        if not os.path.isdir(outdir):
            # makedirs (not mkdir) so nested output paths whose parents do
            # not exist yet also work.
            os.makedirs(outdir)
        self.add_log_file()

        logging.info(
            'Set-up MCMC search for model {} on data {}'.format(
                self.label, self.sftfilepath))
        self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
        self.unpack_input_theta()
        self.ndim = len(self.theta_keys)
        if self.log10temperature_min:
            self.betas = np.logspace(0, self.log10temperature_min, self.ntemps)
        else:
            self.betas = None

        # Fall back to the ephemerides defined on BaseSearchClass.
        if earth_ephem is None:
            self.earth_ephem = self.earth_ephem_default
        if sun_ephem is None:
            self.sun_ephem = self.sun_ephem_default

        # When args.clean is set (presumably a command-line option parsed in
        # core), move previous results aside so they are not reused.
        if args.clean and os.path.isfile(self.pickle_path):
            os.rename(self.pickle_path, self.pickle_path+".old")

        self.log_input()

    def log_input(self):
        """Write the main sampler settings to the log, one line per setting."""
        entries = (
            ('theta_prior = {}', self.theta_prior),
            ('nwalkers={}', self.nwalkers),
            ('scatter_val = {}', self.scatter_val),
            ('nsteps = {}', self.nsteps),
            ('ntemps = {}', self.ntemps),
            ('log10temperature_min = {}', self.log10temperature_min),
        )
        for template, value in entries:
            logging.info(template.format(value))

    def inititate_search_object(self):
        """Create the ComputeFstat object used to evaluate the likelihood.

        The object is stored as ``self.search`` and is always configured as
        non-transient (``transient=False``).
        """
        logging.info('Setting up search object')
        fstat_kwargs = dict(
            tref=self.tref,
            sftfilepath=self.sftfilepath,
            minCoverFreq=self.minCoverFreq,
            maxCoverFreq=self.maxCoverFreq,
            earth_ephem=self.earth_ephem,
            sun_ephem=self.sun_ephem,
            detector=self.detector,
            BSGL=self.BSGL,
            transient=False,
            minStartTime=self.minStartTime,
            maxStartTime=self.maxStartTime,
            binary=self.binary,
            injectSources=self.injectSources,
        )
        self.search = ComputeFstat(**fstat_kwargs)

    def logp(self, theta_vals, theta_prior, theta_keys, search):
        """Return the total log-prior of a point in the searched parameters.

        Parameters
        ----------
        theta_vals: sequence
            Values of the searched parameters, ordered like theta_keys.
        theta_prior: dict
            Maps each key to its prior dictionary (see generic_lnprior).
        theta_keys: list of str
            Names of the searched parameters.
        search:
            Unused; accepted so the signature matches what is passed as
            `logpargs` to the sampler.
        """
        log_priors = []
        for value, key in zip(theta_vals, theta_keys):
            prior_fn = self.generic_lnprior(**theta_prior[key])
            log_priors.append(prior_fn(value))
        return np.sum(log_priors)

    def logl(self, theta, search):
        """Return the fully-coherent detection statistic at ``theta``.

        The searched values in ``theta`` are written into their slots of
        ``self.fixed_theta`` (modified in place), and the full parameter
        vector is passed to the search object.
        """
        full_theta = self.fixed_theta
        for k, full_idx in enumerate(self.theta_idxs):
            full_theta[full_idx] = theta[k]
        return search.compute_fullycoherent_det_stat_single_point(*full_theta)

    def unpack_input_theta(self):
        """Split theta_prior into searched and fixed parameters.

        Sets
        ----
        theta_keys: list of str
            Names of the searched parameters (those whose theta_prior value
            is a dict), ordered as in the full parameter list.
        theta_idxs: list of int
            Position of each searched parameter in the full parameter list.
        theta_symbols: list of str
            Plot labels of the searched parameters.
        fixed_theta: list
            Full parameter vector: fixed values are filled in, searched
            entries hold a placeholder 0 (overwritten by `logl`).

        Raises
        ------
        ValueError
            If a theta_prior value has an unsupported type, a key is not a
            recognised parameter, or a required parameter is missing.
        """
        full_theta_keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta']
        full_theta_symbols = [r'$f$', r'$\dot{f}$', r'$\ddot{f}$',
                              r'$\alpha$', r'$\delta$']
        if self.binary:
            full_theta_keys += ['asini', 'period', 'ecc', 'tp', 'argp']
            # Exactly one symbol per binary key (a duplicated 'period' used
            # to shift the labels of ecc, tp and argp).
            full_theta_symbols += ['asini', 'period', 'ecc', 'tp', 'argp']

        missing_keys = list(full_theta_keys)
        theta_keys = []
        fixed_theta_dict = {}
        for key, val in self.theta_prior.items():
            if isinstance(val, dict):
                fixed_theta_dict[key] = 0  # placeholder, set by logl
                theta_keys.append(key)
            elif type(val) in [float, int, np.float64]:
                fixed_theta_dict[key] = val
            else:
                raise ValueError(
                    'Type {} of {} in theta not recognised'.format(
                        type(val), key))
            if key not in missing_keys:
                raise ValueError(
                    'Key {} in theta not recognised, expected keys: {}'
                    .format(key, full_theta_keys))
            missing_keys.remove(key)

        if missing_keys:
            raise ValueError(('Input dictionary `theta` is missing the '
                              'following keys: {}').format(missing_keys))

        # Order the searched parameters as they appear in full_theta_keys.
        self.theta_keys = sorted(theta_keys, key=full_theta_keys.index)
        self.theta_idxs = [full_theta_keys.index(k) for k in self.theta_keys]
        self.theta_symbols = [full_theta_symbols[i] for i in self.theta_idxs]
        self.fixed_theta = [fixed_theta_dict[key] for key in full_theta_keys]

    def check_initial_points(self, p0):
        """Warn about, and try to repair, walkers starting outside the prior.

        For each temperature the log-prior of every initial walker is
        evaluated; if any is -inf, `generate_new_p0_to_fix_initial_points`
        is asked to re-draw them. That repair writes into ``p0`` in place;
        this method itself returns None.
        """
        for temp_idx in range(self.ntemps):
            logging.info('Checking temperature {} chains'.format(temp_idx))
            log_priors = np.array(
                [self.logp(point, self.theta_prior, self.theta_keys,
                           self.search)
                 for point in p0[temp_idx]])
            n_out_of_bounds = sum(log_priors == -np.inf)
            if n_out_of_bounds <= 0:
                continue

            logging.warning(
                'Of {} initial values, {} are -np.inf due to the prior'
                .format(len(log_priors), n_out_of_bounds))
            p0 = self.generate_new_p0_to_fix_initial_points(
                p0, temp_idx, log_priors)

    def generate_new_p0_to_fix_initial_points(self, p0, nt, initial_priors):
        """Re-draw walkers at temperature ``nt`` whose log-prior is -inf.

        Each out-of-prior walker is replaced by a copy of a randomly chosen
        walker that is currently inside the prior, perturbed
        multiplicatively by N(0, 1e-10) noise. This is repeated (at most
        100 times) until every walker at this temperature has a log-prior
        above -inf. If no walker is inside the prior yet, donors are drawn
        from all walkers.

        Parameters
        ----------
        p0: indexed as p0[temperature][walker]
            Initial walker positions; modified in place.
        nt: int
            Index of the temperature whose walkers are repaired.
        initial_priors: array-like
            Log-prior of each walker in p0[nt].

        Returns
        -------
        p0: the same object that was passed in, after repair.
        """
        logging.info('Attempting to correct intial values')
        initial_priors = np.asarray(initial_priors)
        count = 0
        while np.any(initial_priors == -np.inf) and count < 100:
            # Recompute each pass: only walkers that are still invalid are
            # re-drawn, so repaired walkers are not overwritten again.
            bad = np.flatnonzero(initial_priors == -np.inf)
            good = np.flatnonzero(initial_priors > -np.inf)
            for j in bad:
                if len(good) > 0:
                    donor = np.random.choice(good)
                else:
                    donor = np.random.randint(0, self.nwalkers)
                p0[nt][j] = (p0[nt][donor] *
                             (1 + np.random.normal(0, 1e-10, self.ndim)))
            initial_priors = np.array([
                self.logp(p, self.theta_prior, self.theta_keys,
                          self.search)
                for p in p0[nt]])
            count += 1

        if np.any(initial_priors == -np.inf):
            logging.warning('Failed to fix initial priors')
        else:
            logging.info('Suceeded to fix initial priors')

        return p0

Gregory Ashton's avatar
Gregory Ashton committed
209
    def run_sampler_with_progress_bar(self, sampler, ns, p0):
        """Advance ``sampler`` by ``ns`` iterations from ``p0`` with a tqdm bar.

        The per-iteration results yielded by ``sampler.sample`` are
        discarded; the same sampler object is returned.
        """
        iterator = sampler.sample(p0, iterations=ns)
        for _ in tqdm(iterator, total=ns):
            continue
        return sampler

214
    def run(self, proposal_scale_factor=2, create_plots=True, **kwargs):
        """Run the parallel-tempered MCMC search, or reload a previous run.

        If `check_old_data_is_okay_to_use` returns True, sampler, samples,
        lnprobs and lnlikes are loaded from the saved data and nothing is
        run. Otherwise an emcee.PTSampler is built, one initialisation stage
        is run per entry of nsteps[:-2] (each restarting from `get_new_p0`
        and resetting the sampler), then a final run of
        nsteps[-2] + nsteps[-1] steps. The temperature-index-0 samples after
        the burn-in are stored on the instance and written via `save_data`.

        Parameters
        ----------
        proposal_scale_factor: float
            Passed to emcee.PTSampler as the proposal scale `a`.
        create_plots: bool
            If true, save a walker plot for every stage to outdir.
        **kwargs:
            Forwarded to `plot_walkers`.
        """
        self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
        if self.old_data_is_okay_to_use is True:
            logging.warning('Using saved data from {}'.format(
                self.pickle_path))
            d = self.get_saved_data()
            self.sampler = d['sampler']
            self.samples = d['samples']
            self.lnprobs = d['lnprobs']
            self.lnlikes = d['lnlikes']
            return

        self.inititate_search_object()

        sampler = emcee.PTSampler(
            self.ntemps, self.nwalkers, self.ndim, self.logl, self.logp,
            logpargs=(self.theta_prior, self.theta_keys, self.search),
            loglargs=(self.search,), betas=self.betas, a=proposal_scale_factor)

        p0 = self.generate_initial_p0()
        p0 = self.apply_corrections_to_p0(p0)
        # check_initial_points repairs out-of-prior walkers in place.
        self.check_initial_points(p0)

        ninit_steps = len(self.nsteps) - 2
        for j, n in enumerate(self.nsteps[:-2]):
            # 1-based stage counter in the log ("1/2", "2/2").
            logging.info('Running {}/{} initialisation with {} steps'.format(
                j + 1, ninit_steps, n))
            sampler = self.run_sampler_with_progress_bar(sampler, n, p0)
            self._log_acceptance_fractions(sampler)

            if create_plots:
                fig, axes = self.plot_walkers(sampler,
                                              symbols=self.theta_symbols,
                                              **kwargs)
                fig.tight_layout()
                fig.savefig('{}/{}_init_{}_walkers.png'.format(
                    self.outdir, self.label, j), dpi=200)
                # Release the figure; otherwise every stage leaks one.
                plt.close(fig)

            p0 = self.get_new_p0(sampler)
            p0 = self.apply_corrections_to_p0(p0)
            self.check_initial_points(p0)
            sampler.reset()

        if len(self.nsteps) > 1:
            nburn = self.nsteps[-2]
        else:
            nburn = 0
        nprod = self.nsteps[-1]
        logging.info('Running final burn and prod with {} steps'.format(
            nburn+nprod))
        sampler = self.run_sampler_with_progress_bar(sampler, nburn+nprod, p0)
        self._log_acceptance_fractions(sampler)

        if create_plots:
            fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols,
                                          burnin_idx=nburn, **kwargs)
            fig.tight_layout()
            fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label),
                        dpi=200)
            plt.close(fig)

        samples = sampler.chain[0, :, nburn:, :].reshape((-1, self.ndim))
        lnprobs = sampler.lnprobability[0, :, nburn:].reshape((-1))
        lnlikes = sampler.lnlikelihood[0, :, nburn:].reshape((-1))
        self.sampler = sampler
        self.samples = samples
        self.lnprobs = lnprobs
        self.lnlikes = lnlikes
        self.save_data(sampler, samples, lnprobs, lnlikes)

    def _log_acceptance_fractions(self, sampler):
        """Log the walker-averaged acceptance fraction (and, if ntemps > 1,
        the temperature-swap acceptance fraction) of ``sampler``."""
        logging.info("Mean acceptance fraction: {}"
                     .format(np.mean(sampler.acceptance_fraction, axis=1)))
        if self.ntemps > 1:
            logging.info("Tswap acceptance fraction: {}"
                         .format(sampler.tswap_acceptance_fraction))

291
    def plot_corner(self, figsize=(7, 7),  tglitch_ratio=False,
                    add_prior=False, nstds=None, label_offset=0.4,
                    dpi=300, rc_context={}, **kwargs):
        """Make a corner plot of the posterior samples and save it.

        The figure is saved to ``<outdir>/<label>_corner.png``. With a
        single searched parameter a plain histogram is produced instead.

        Parameters
        ----------
        figsize: tuple
            Figure size passed to plt.subplots.
        tglitch_ratio: bool
            If true, a 'tglitch' parameter is plotted as
            (tglitch - minStartTime)/(maxStartTime - minStartTime).
        add_prior: bool
            If true, overlay the prior on the diagonal panels.
        nstds: int, optional
            If given and no explicit `range` kwarg is passed, each axis range
            is median +/- nstds * std of the samples.
        label_offset: float
            Offset passed to set_label_coords for the outer axis labels.
        dpi: int
            Resolution of the saved figure.
        rc_context: dict
            matplotlib rcParams applied while plotting (not modified).
        **kwargs:
            Forwarded to corner.corner. An explicit `range` is used in place
            of the nstds-derived one.
        """
        if self.ndim < 2:
            with plt.rc_context(rc_context):
                fig, ax = plt.subplots(figsize=figsize)
                ax.hist(self.samples, bins=50, histtype='stepfilled')
                ax.set_xlabel(self.theta_symbols[0])

            fig.savefig('{}/{}_corner.png'.format(
                self.outdir, self.label), dpi=dpi)
            return

        with plt.rc_context(rc_context):
            fig, axes = plt.subplots(self.ndim, self.ndim,
                                     figsize=figsize)

            samples_plt = copy.copy(self.samples)
            theta_symbols_plt = [s.replace('_{glitch}', r'_\textrm{glitch}')
                                 for s in self.theta_symbols]

            if tglitch_ratio:
                for j, k in enumerate(self.theta_keys):
                    if k == 'tglitch':
                        s = samples_plt[:, j]
                        samples_plt[:, j] = (
                            s - self.minStartTime)/(
                                self.maxStartTime - self.minStartTime)
                        theta_symbols_plt[j] = r'$R_{\textrm{glitch}}$'

            if 'range' in kwargs:
                # Pop it: it is passed explicitly below, and leaving it in
                # kwargs raised "multiple values for keyword argument".
                _range = kwargs.pop('range')
            elif type(nstds) is int:
                _range = []
                for s in samples_plt.T:
                    median = np.median(s)
                    std = np.std(s)
                    _range.append((median - nstds*std, median + nstds*std))
            else:
                _range = None

            fig_triangle = corner.corner(samples_plt,
                                         labels=theta_symbols_plt,
                                         fig=fig,
                                         bins=50,
                                         max_n_ticks=4,
                                         plot_contours=True,
                                         plot_datapoints=True,
                                         label_kwargs={'fontsize': 8},
                                         data_kwargs={'alpha': 0.1,
                                                      'ms': 0.5},
                                         range=_range,
                                         **kwargs)

            axes_list = fig_triangle.get_axes()
            axes = np.array(axes_list).reshape(self.ndim, self.ndim)
            plt.draw()
            for ax in axes[:, 0]:
                ax.yaxis.set_label_coords(-label_offset, 0.5)
            for ax in axes[-1, :]:
                ax.xaxis.set_label_coords(0.5, -label_offset)
            for ax in axes_list:
                ax.set_rasterized(True)
                ax.set_rasterization_zorder(-10)
            plt.tight_layout(h_pad=0.0, w_pad=0.0)
            fig.subplots_adjust(hspace=0.05, wspace=0.05)

            if add_prior:
                self.add_prior_to_corner(axes, samples_plt)

            fig_triangle.savefig('{}/{}_corner.png'.format(
                self.outdir, self.label), dpi=dpi)

    def add_prior_to_corner(self, axes, samples):
        """Overlay each searched parameter's prior on the corner diagonal.

        For each searched parameter the prior density is evaluated on 100
        points spanning the sample range. It is drawn in red on a twin axis
        (y-axis hidden) of the diagonal panel, and the panel's x-limits are
        preserved.

        Parameters
        ----------
        axes: 2D array of axes, indexed axes[i][i] for the diagonal
        samples: np.ndarray indexed samples[:, i] per parameter
        """
        for i, key in enumerate(self.theta_keys):
            ax = axes[i][i]
            xlim = ax.get_xlim()
            s = samples[:, i]
            lnprior = self.generic_lnprior(**self.theta_prior[key])
            x = np.linspace(s.min(), s.max(), 100)
            ax2 = ax.twinx()
            ax2.get_yaxis().set_visible(False)
            # generic_lnprior returns the log density; exponentiate so the
            # curve has the shape of the prior pdf.
            ax2.plot(x, np.exp([lnprior(xi) for xi in x]), '-r')
            ax.set_xlim(xlim)

376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
    def plot_prior_posterior(self, normal_stds=2):
        """ Plot the posterior in the context of the prior

        One row per searched parameter. The prior density is drawn in red
        over its support ('unif'), or over ``normal_stds`` scales from
        `loc` for the normal-type priors. A Gaussian KDE of the posterior
        samples is drawn in black on a twin axis over the same x-range. The
        figure is saved to ``<outdir>/<label>_prior_posterior.png``.

        Parameters
        ----------
        normal_stds: float
            Number of `scale` units used for the plotting range of 'norm',
            'halfnorm' and 'neghalfnorm' priors.

        Raises
        ------
        ValueError
            If a searched parameter has a prior type not handled here.
        """
        from scipy.stats import gaussian_kde

        fig, axes = plt.subplots(nrows=self.ndim, figsize=(8, 4*self.ndim))
        # plt.subplots returns a bare Axes (not an array) when ndim == 1.
        axes = np.atleast_1d(axes)
        N = 1000

        for i, (ax, key) in enumerate(zip(axes, self.theta_keys)):
            prior_dict = self.theta_prior[key]
            prior_func = self.generic_lnprior(**prior_dict)
            if prior_dict['type'] == 'unif':
                x = np.linspace(prior_dict['lower'], prior_dict['upper'], N)
                lnprior = prior_func(x)
            elif prior_dict['type'] == 'norm':
                lower = prior_dict['loc'] - normal_stds * prior_dict['scale']
                upper = prior_dict['loc'] + normal_stds * prior_dict['scale']
                x = np.linspace(lower, upper, N)
                lnprior = prior_func(x)
            elif prior_dict['type'] == 'halfnorm':
                lower = prior_dict['loc']
                upper = prior_dict['loc'] + normal_stds * prior_dict['scale']
                x = np.linspace(lower, upper, N)
                lnprior = [prior_func(xi) for xi in x]
            elif prior_dict['type'] == 'neghalfnorm':
                upper = prior_dict['loc']
                lower = prior_dict['loc'] - normal_stds * prior_dict['scale']
                x = np.linspace(lower, upper, N)
                lnprior = [prior_func(xi) for xi in x]
            else:
                raise ValueError('Not implemented for prior type {}'.format(
                    prior_dict['type']))
            # generic_lnprior returns log densities; plot the density so it
            # is comparable with the posterior KDE. For 'unif' the excluded
            # end points give exp(-inf) == 0, drawing the edges of the box.
            prior = np.exp(lnprior)
            priorln = ax.plot(x, prior, 'r', label='prior')
            ax.set_xlabel(self.theta_symbols[i])

            s = self.samples[:, i]
            while len(s) > 10**4:
                # random downsample to avoid slow calculation of kde
                s = np.random.choice(s, size=int(len(s)/2.))
            kde = gaussian_kde(s)
            ax2 = ax.twinx()
            postln = ax2.plot(x, kde.pdf(x), 'k', label='posterior')
            ax2.set_yticklabels([])
            ax.set_yticklabels([])

        lns = priorln + postln
        labs = [l.get_label() for l in lns]
        axes[0].legend(lns, labs, loc=1, framealpha=0.8)

        fig.savefig('{}/{}_prior_posterior.png'.format(
            self.outdir, self.label))

428
    def plot_cumulative_max(self, **kwargs):
        """Plot the cumulative detection statistic at the max-2F sample.

        The maximum-2F point from `get_max_twoF` is completed with any fixed
        parameters from theta_prior. It is then passed, with
        tstart=minStartTime and tend=maxStartTime, to
        `self.search.plot_twoF_cumulative` (creating the search object if
        needed). Extra keyword arguments are forwarded and override the
        parameter values taken from the max-2F point.
        """
        d, _ = self.get_max_twoF()
        for key, val in self.theta_prior.items():
            if key not in d:
                d[key] = val

        if not hasattr(self, 'search'):
            self.inititate_search_object()

        param_keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta']
        if self.binary:
            param_keys += ['asini', 'period', 'ecc', 'argp', 'tp']
        # Each parameter takes its own value (tp previously received argp).
        params = {key: d[key] for key in param_keys}
        params.update(kwargs)
        self.search.plot_twoF_cumulative(
            self.label, self.outdir,
            tstart=self.minStartTime, tend=self.maxStartTime, **params)

Gregory Ashton's avatar
Gregory Ashton committed
449
    def generic_lnprior(self, **kwargs):
        """ Return a lambda function of the log pdf

        Parameters
        ----------
        kwargs: dict
            A dictionary containing 'type' of pdf and shape parameters:
            'unif' (lower, upper), 'norm' (loc, scale), 'halfnorm'
            (loc, scale) or 'neghalfnorm' (loc, scale; evaluated as the
            'halfnorm' log pdf at -x).

        Returns
        -------
        callable
            f(x) -> log prior density; -np.inf outside the support. The
            'unif' and 'norm' functions accept scalars or numpy arrays;
            'halfnorm'/'neghalfnorm' accept scalars only.

        Raises
        ------
        ValueError
            If kwargs['type'] is not one of the types above.
        """

        def logunif(x, a, b):
            above = x < b
            below = x > a
            if not isinstance(above, np.ndarray):
                # scalar input
                if above and below:
                    return -np.log(b-a)
                return -np.inf
            # array input: -inf outside the open interval (a, b)
            return np.where(above & below, -np.log(b-a), -np.inf)

        def halfnorm(x, loc, scale):
            if x < loc:
                return -np.inf
            return -0.5*((x-loc)**2/scale**2+np.log(0.5*np.pi*scale**2))

        prior_type = kwargs['type']
        if prior_type == 'unif':
            return lambda x: logunif(x, kwargs['lower'], kwargs['upper'])
        elif prior_type == 'halfnorm':
            return lambda x: halfnorm(x, kwargs['loc'], kwargs['scale'])
        elif prior_type == 'neghalfnorm':
            # NOTE(review): the support here is x <= -loc, whereas
            # plot_prior_posterior plots neghalfnorm on [loc - n*scale, loc];
            # these agree only for loc == 0 -- confirm intended.
            return lambda x: halfnorm(-x, kwargs['loc'], kwargs['scale'])
        elif prior_type == 'norm':
            return lambda x: -0.5*((x - kwargs['loc'])**2/kwargs['scale']**2
                                   + np.log(2*np.pi*kwargs['scale']**2))
        else:
            logging.info("kwargs: %s", kwargs)
            raise ValueError(
                "Unrecognised prior distribution type {}".format(prior_type))

Gregory Ashton's avatar
Gregory Ashton committed
501
    def generate_rv(self, **kwargs):
        """Draw one random value from the distribution described by kwargs.

        kwargs['type'] selects the distribution and the remaining keys are
        its parameters:
          'unif'        uniform on [lower, upper)
          'norm'        N(loc, scale)
          'halfnorm'    |N(loc, scale)|
          'neghalfnorm' -|N(loc, scale)|
          'lognorm'     lognormal with underlying mean=loc, sigma=scale

        Raises ValueError for any other type.
        """
        dist_type = kwargs.pop('type')
        draws = {
            "unif": lambda: np.random.uniform(
                low=kwargs['lower'], high=kwargs['upper']),
            "norm": lambda: np.random.normal(
                loc=kwargs['loc'], scale=kwargs['scale']),
            "halfnorm": lambda: np.abs(np.random.normal(
                loc=kwargs['loc'], scale=kwargs['scale'])),
            "neghalfnorm": lambda: -1 * np.abs(np.random.normal(
                loc=kwargs['loc'], scale=kwargs['scale'])),
            "lognorm": lambda: np.random.lognormal(
                mean=kwargs['loc'], sigma=kwargs['scale']),
        }
        draw = draws.get(dist_type)
        if draw is None:
            raise ValueError("dist_type {} unknown".format(dist_type))
        return draw()

Gregory Ashton's avatar
Gregory Ashton committed
519
    def plot_walkers(self, sampler, symbols=None, alpha=0.4, color="k", temp=0,
                     lw=0.1, burnin_idx=None, add_det_stat_burnin=False,
                     fig=None, axes=None, xoffset=0, plot_det_stat=True,
                     context='classic', subtractions=None, labelpad=0.05):
        """ Plot all the chains from a sampler

        Parameters
        ----------
        sampler:
            Its `chain` must be 3D (nwalkers, nsteps, ndim) or 4D
            (ntemps, nwalkers, nsteps, ndim). `lnlikelihood[temp]` is read
            for the detection-statistic histogram.
        symbols: list of str, optional
            y-axis labels, one per parameter.
        alpha, lw: float
            Transparency and line width of the walker traces.
        color: str
            NOTE(review): unused; traces are red (burn-in) and black.
        temp: int
            Temperature index to plot (4D chains).
        burnin_idx: int, optional
            Steps before this index are drawn in red as burn-in.
        add_det_stat_burnin: bool
            Also histogram the burn-in detection statistic (in red).
        fig, axes: optional
            Existing figure/axes to draw into; created if both are None.
        xoffset: int
            Added to the step index on the x-axis.
        plot_det_stat: bool
            Add a histogram of the detection statistic as the last panel.
        context: str
            matplotlib style context used while drawing.
        subtractions: list of float, optional
            Per-parameter values subtracted from the traces (length ndim).
        labelpad: float
            Padding of the y-axis labels.

        Returns
        -------
        fig, axes

        Raises
        ------
        ValueError
            If temp is out of range, the chain is not 3D/4D, or subtractions
            has the wrong length.
        """

        if np.ndim(axes) > 1:
            axes = axes.flatten()

        shape = sampler.chain.shape
        if len(shape) == 3:
            nwalkers, nsteps, ndim = shape
            chain = sampler.chain[:, :, :]
        elif len(shape) == 4:
            ntemps, nwalkers, nsteps, ndim = shape
            if temp < ntemps:
                logging.info("Plotting temperature {} chains".format(temp))
            else:
                raise ValueError(("Requested temperature {} outside of "
                                  "available range").format(temp))
            chain = sampler.chain[temp, :, :, :]
        else:
            raise ValueError(
                'sampler.chain must be 3 or 4 dimensional, got shape {}'
                .format(shape))

        if subtractions is None:
            subtractions = [0 for i in range(ndim)]
        elif len(subtractions) != ndim:
            raise ValueError('subtractions must be of length ndim')

        with plt.style.context((context)):
            plt.rcParams['text.usetex'] = True
            if fig is None and axes is None:
                fig = plt.figure(figsize=(4, 3.0*ndim))
                ax = fig.add_subplot(ndim+1, 1, 1)
                axes = [ax] + [fig.add_subplot(ndim+1, 1, i)
                               for i in range(2, ndim+1)]

            idxs = np.arange(chain.shape[1])
            # chain is (nwalkers, nsteps, ndim) for the selected temperature,
            # so parameter i is chain[:, :, i] for any ndim, including 1
            # (previously ndim == 1 wrongly indexed chain[:, :, temp]).
            for i in range(ndim):
                axes[i].ticklabel_format(useOffset=False, axis='y')
                cs = chain[:, :, i].T
                if burnin_idx:
                    axes[i].plot(xoffset+idxs[:burnin_idx],
                                 cs[:burnin_idx]-subtractions[i],
                                 color="r", alpha=alpha, lw=lw)
                axes[i].plot(xoffset+idxs[burnin_idx:],
                             cs[burnin_idx:]-subtractions[i],
                             color="k", alpha=alpha, lw=lw)
                if symbols:
                    if subtractions[i] == 0:
                        axes[i].set_ylabel(symbols[i], labelpad=labelpad)
                    else:
                        axes[i].set_ylabel(
                            symbols[i]+'$-$'+symbols[i]+'$_0$',
                            labelpad=labelpad)

            if plot_det_stat:
                if len(axes) == ndim:
                    axes.append(fig.add_subplot(ndim+1, 1, ndim+1))

                lnl = sampler.lnlikelihood[temp, :, :]
                if burnin_idx and add_det_stat_burnin:
                    burn_in_vals = lnl[:, :burnin_idx].flatten()
                    try:
                        axes[-1].hist(burn_in_vals[~np.isnan(burn_in_vals)],
                                      bins=50, histtype='step', color='r')
                    except ValueError:
                        logging.info('Det. Stat. hist failed, most likely all '
                                     'values where the same')
                else:
                    burn_in_vals = []
                prod_vals = lnl[:, burnin_idx:].flatten()
                try:
                    axes[-1].hist(prod_vals[~np.isnan(prod_vals)], bins=50,
                                  histtype='step', color='k')
                except ValueError:
                    logging.info('Det. Stat. hist failed, most likely all '
                                 'values where the same')
                if self.BSGL:
                    axes[-1].set_xlabel(r'$\mathcal{B}_\mathrm{S/GL}$')
                else:
                    axes[-1].set_xlabel(r'$\widetilde{2\mathcal{F}}$')
                axes[-1].set_ylabel(r'$\textrm{Counts}$')
                combined_vals = np.append(burn_in_vals, prod_vals)
                # Drop NaN/inf so the axis limits stay finite.
                combined_vals = combined_vals[np.isfinite(combined_vals)]
                if len(combined_vals) > 0:
                    minv = np.min(combined_vals)
                    maxv = np.max(combined_vals)
                    Range = abs(maxv-minv)
                    axes[-1].set_xlim(minv-0.1*Range, maxv+0.1*Range)

                xfmt = matplotlib.ticker.ScalarFormatter()
                xfmt.set_powerlimits((-4, 4))
                axes[-1].xaxis.set_major_formatter(xfmt)

            # Label the last trace axis explicitly; axes[-2] was wrong (or
            # out of range) when the det-stat panel is not added.
            axes[ndim-1].set_xlabel(r'$\textrm{Number of steps}$',
                                    labelpad=0.2)
        return fig, axes

Gregory Ashton's avatar
Gregory Ashton committed
630
631
632
633
634
    def apply_corrections_to_p0(self, p0):
        """Hook for adjusting initial walker positions before sampling.

        The base implementation is the identity: ``p0`` (the same object) is
        returned unchanged.
        """
        corrected_p0 = p0
        return corrected_p0

    def generate_scattered_p0(self, p):
        """ Generate a set of p0s scattered about p

        Returns a nested list indexed [temperature][walker]. Each entry is
        p + scatter_val * p * N(0, 1), drawn independently per coordinate,
        i.e. a relative scatter of size scatter_val (coordinates where
        p == 0 are therefore not scattered).

        Parameters
        ----------
        p: array-like of length ndim supporting elementwise arithmetic
            (e.g. np.ndarray).
        """
        # range rather than the Python-2-only xrange.
        p0 = [[p + self.scatter_val * p * np.random.randn(self.ndim)
               for i in range(self.nwalkers)]
              for j in range(self.ntemps)]
        return p0

Gregory Ashton's avatar
Gregory Ashton committed
641
    def generate_initial_p0(self):
        """ Generate a set of init vals for the walkers

        How the initial positions are produced depends on
        ``self.theta_initial``:

        - a dict (including OrderedDict): for every key in ``theta_keys`` draw
          ``generate_rv(**theta_initial[key])``;
        - a list: one distribution dictionary per dimension, in the order of
          ``theta_keys``;
        - None: draw from the prior dictionaries in ``theta_prior``;
        - any other 1-D sequence of length ``ndim``: treat it as a point and
          scatter the walkers about it via ``generate_scattered_p0``.

        Returns
        -------
        p0: nested list (or whatever generate_scattered_p0 returns) indexed
            as [temperature][walker][dimension]

        Raises
        ------
        ValueError
            If a dict is given while nglitch > 1, or theta_initial cannot be
            interpreted.
        """
        # isinstance (not type(...) ==) so that e.g. OrderedDict is treated
        # as a dictionary rather than falling through to the point branch.
        if isinstance(self.theta_initial, dict):
            logging.info('Generate initial values from initial dictionary')
            if hasattr(self, 'nglitch') and self.nglitch > 1:
                raise ValueError('Initial dict not implemented for nglitch>1')
            p0 = [[[self.generate_rv(**self.theta_initial[key])
                    for key in self.theta_keys]
                   for i in range(self.nwalkers)]
                  for j in range(self.ntemps)]
        elif isinstance(self.theta_initial, list):
            logging.info('Generate initial values from list of theta_initial')
            p0 = [[[self.generate_rv(**val)
                    for val in self.theta_initial]
                   for i in range(self.nwalkers)]
                  for j in range(self.ntemps)]
        elif self.theta_initial is None:
            logging.info('Generate initial values from prior dictionary')
            p0 = [[[self.generate_rv(**self.theta_prior[key])
                    for key in self.theta_keys]
                   for i in range(self.nwalkers)]
                  for j in range(self.ntemps)]
        elif (np.ndim(self.theta_initial) == 1
              and len(self.theta_initial) == self.ndim):
            # A single point: convert so tuples/arrays behave the same
            p0 = self.generate_scattered_p0(
                np.asarray(self.theta_initial, dtype=float))
        else:
            raise ValueError('theta_initial not understood')

        return p0

671
    def get_new_p0(self, sampler):
        """ Returns new initial positions for walkers after the burn0 stage

        This returns new positions for all walkers by scattering points about
        the maximum finite log-posterior of the temp_idx = 0 chain, with
        scale `scatter_val`.

        Parameters
        ----------
        sampler:
            The sampler from the previous stage; its ``chain``,
            ``lnlikelihood`` and ``lnprobability`` attributes are read and
            indexed as [temperature, walker, step(, dim)].

        Returns
        -------
        p0: whatever ``generate_scattered_p0`` returns for the best point

        Raises
        ------
        ValueError
            If no finite lnprob value exists in the temp_idx = 0 chain.
        """
        temp_idx = 0
        pF = sampler.chain[temp_idx, :, :, :]
        lnl = sampler.lnlikelihood[temp_idx, :, :]
        lnp = sampler.lnprobability[temp_idx, :, :]

        # General warnings about the state of lnp
        if np.any(np.isnan(lnp)):
            logging.warning(
                "Of {} lnprobs {} are nan".format(
                    np.shape(lnp), np.sum(np.isnan(lnp))))
        if np.any(np.isposinf(lnp)):
            logging.warning(
                "Of {} lnprobs {} are +np.inf".format(
                    np.shape(lnp), np.sum(np.isposinf(lnp))))
        if np.any(np.isneginf(lnp)):
            logging.warning(
                "Of {} lnprobs {} are -np.inf".format(
                    np.shape(lnp), np.sum(np.isneginf(lnp))))

        # Treat +/-inf as missing so nanargmax picks the best finite value
        lnp_finite = np.where(np.isinf(lnp), np.nan, lnp)
        if np.all(np.isnan(lnp_finite)):
            raise ValueError(
                'Unable to generate new p0: no finite lnprob values found')
        idx = np.unravel_index(np.nanargmax(lnp_finite), lnp_finite.shape)
        p = pF[idx]
        p0 = self.generate_scattered_p0(p)

        # Evaluate the plain twoF (BSGL switched off) at the chosen point for
        # the log message; always restore the flag, even if logl raises.
        self.search.BSGL = False
        try:
            twoF = self.logl(p, self.search)
        finally:
            self.search.BSGL = self.BSGL

        logging.info(('Gen. new p0 from pos {} which had det. stat.={:2.1f},'
                      ' twoF={:2.1f} and lnp={:2.1f}')
                     .format(idx[1], lnl[idx], twoF, lnp_finite[idx]))

        return p0

    def get_save_data_dictionary(self):
        """ Collect the run settings that identify a saved MCMC result

        The returned dictionary is pickled alongside the samples by
        ``save_data`` and compared against the saved one by
        ``check_old_data_is_okay_to_use`` to decide whether old data can be
        reused.
        """
        setting_names = ('nsteps', 'nwalkers', 'ntemps', 'theta_keys',
                         'theta_prior', 'scatter_val', 'log10temperature_min',
                         'BSGL')
        return {name: getattr(self, name) for name in setting_names}

    def save_data(self, sampler, samples, lnprobs, lnlikes):
        """ Pickle the run settings together with the sampler output

        The dictionary from ``get_save_data_dictionary`` is extended with the
        sampler and the samples, lnprobs and lnlikes, then written to
        ``self.pickle_path``. Any file already at that path is first moved
        aside to ``<pickle_path>.old``.
        """
        payload = self.get_save_data_dictionary()
        payload.update(sampler=sampler, samples=samples, lnprobs=lnprobs,
                       lnlikes=lnlikes)

        target = self.pickle_path
        if os.path.isfile(target):
            logging.info('Saving backup of {} as {}.old'.format(
                target, target))
            os.rename(target, target + ".old")
        with open(target, "wb") as handle:
            pickle.dump(payload, handle)

    def get_saved_data(self):
        """ Load and return the dictionary pickled at ``self.pickle_path``

        The file is opened in binary mode to match ``save_data`` (which
        writes with "wb"); text mode breaks unpickling on Python 3 and can
        corrupt binary pickles on platforms with newline translation.

        Note: unpickling can execute arbitrary code, so only load files that
        were produced by ``save_data``.
        """
        with open(self.pickle_path, "rb") as File:
            d = pickle.load(File)
        return d

    def check_old_data_is_okay_to_use(self):
        """ Decide whether previously pickled results can be reused

        Returns True when forced by the ``use_old_data`` command-line flag, or
        when the pickle exists, is not older than the oldest matching SFT and
        was produced with the same settings as ``get_save_data_dictionary``
        now returns. Otherwise logs the reason and returns False.

        Raises
        ------
        ValueError
            If a key of the current settings is missing from the saved data.
        """

        def values_differ(old, new):
            # Settings may hold numpy arrays (e.g. an ndim-array scatter_val)
            # for which `!=` is element-wise and cannot be used in an `if`.
            try:
                return bool(old != new)
            except ValueError:
                pass
            try:
                return not np.array_equal(old, new)
            except (ValueError, TypeError):
                # Cannot tell whether they match: be safe and do not reuse
                return True

        if args.use_old_data:
            logging.info("Forcing use of old data")
            return True

        if os.path.isfile(self.pickle_path) is False:
            logging.info('No pickled data found')
            return False

        oldest_sft = min([os.path.getmtime(f) for f in
                          self.get_list_of_matching_sfts()])
        if os.path.getmtime(self.pickle_path) < oldest_sft:
            logging.info('Pickled data outdates sft files')
            return False

        old_d = self.get_saved_data().copy()
        new_d = self.get_save_data_dictionary().copy()

        # Sampler output is not a setting, so exclude it from the comparison
        for output_key in ('samples', 'sampler', 'lnprobs', 'lnlikes'):
            old_d.pop(output_key)

        mod_keys = []
        for key in new_d.keys():
            if key not in old_d:
                raise ValueError('Keys {} not in old dictionary'.format(key))
            if values_differ(old_d[key], new_d[key]):
                mod_keys.append((key, old_d[key], new_d[key]))

        if len(mod_keys) == 0:
            return True

        logging.warning("Saved data differs from requested")
        logging.info("Differences found in following keys:")
        for key, old_val, new_val in mod_keys:
            if np.isscalar(old_val) or key == 'nsteps':
                logging.info("    {} : {} -> {}".format(key, old_val, new_val))
            else:
                logging.info("    " + key)
        return False

    def get_max_twoF(self, threshold=0.05):
        """ Returns the max likelihood sample and the corresponding 2F value

        Parameters
        ----------
        threshold: float
            Currently unused; kept for backwards compatibility.

        Returns
        -------
        d: OrderedDict
            Parameters of the sample with the largest finite entry of
            ``self.lnlikes``, keyed by ``theta_keys``. A key that appears more
            than once (e.g. one per glitch) is relabelled ``key_0``,
            ``key_1``, ... in order of appearance.
        maxtwoF: float
            The stored lnlike at that sample, or, when ``BSGL`` is set, twoF
            recomputed at that sample with BSGL switched off.

        Raises
        ------
        ValueError
            If ``self.lnlikes`` contains no finite values.
        """
        lnlikes = np.asarray(self.lnlikes, dtype=float)
        if np.any(np.isposinf(lnlikes)):
            logging.info('twoF values contain positive infinite values')
        if np.any(np.isneginf(lnlikes)):
            logging.info('twoF values contain negative infinite values')
        if np.any(np.isnan(lnlikes)):
            logging.info('twoF values contain nan')
        finite = np.isfinite(lnlikes)
        if not np.any(finite):
            raise ValueError('twoF values contain no finite values')
        # Mask (rather than filter out) the non-finite values: jmax must be an
        # index into the full lnlikes/samples arrays.
        jmax = np.nanargmax(np.where(finite, lnlikes, np.nan))
        maxlogl = lnlikes[jmax]

        if self.BSGL:
            if hasattr(self, 'search') is False:
                self.inititate_search_object()
            p = self.samples[jmax]
            # Temporarily switch BSGL off to get the plain twoF; restore the
            # flag even if logl raises.
            self.search.BSGL = False
            try:
                maxtwoF = self.logl(p, self.search)
            finally:
                self.search.BSGL = self.BSGL
        else:
            maxtwoF = maxlogl

        d = OrderedDict()
        repeats = []
        for i, k in enumerate(self.theta_keys):
            if k in d and k not in repeats:
                d[k+'_0'] = d.pop(k)  # relabel the first occurrence
                repeats.append(k)
            if k in repeats:
                # First free suffix, built from the base key so that keys
                # already containing '_<n>' are not mangled
                count = 0
                while '{}_{}'.format(k, count) in d:
                    count += 1
                k = '{}_{}'.format(k, count)
            d[k] = self.samples[jmax][i]
        return d, maxtwoF

    def get_median_stds(self):
        """ Returns a dict of the median and std of all production samples

        For each parameter in ``theta_keys`` the OrderedDict holds the median
        under ``key`` and the standard deviation under ``key + '_std'``. A key
        that appears more than once (e.g. one per glitch) is relabelled
        ``key_0``, ``key_1``, ... in order of appearance.
        """
        d = OrderedDict()
        repeats = []
        for s, k in zip(self.samples.T, self.theta_keys):
            if k in d and k not in repeats:
                # Second occurrence of k: rename the first one to k_0
                d[k+'_0'] = d.pop(k)
                d[k+'_0_std'] = d.pop(k+'_std')
                repeats.append(k)
            if k in repeats:
                # First free suffix, built from the base key so that keys
                # already containing '_<n>' are not mangled
                count = 0
                while '{}_{}'.format(k, count) in d:
                    count += 1
                k = '{}_{}'.format(k, count)
            d[k] = np.median(s)
            d[k+'_std'] = np.std(s)
        return d

    def write_par(self, method='med'):
        """ Writes a .par of the best-fit params with an estimated std

        The file ``{outdir}/{label}.par`` always contains the max twoF and
        tref, plus ``theta0_index`` when the search defines ``theta0_idx``.

        Parameters
        ----------
        method: str, {'med', 'twoFmax'}
            'med' writes the median and standard deviation of every parameter
            over the production samples; 'twoFmax' writes the sample with the
            maximum detection statistic. Any other value is logged as a
            warning and no parameter lines are written.
        """
        logging.info('Writing {}/{}.par using the {} method'.format(
            self.outdir, self.label, method))

        median_std_d = self.get_median_stds()
        max_twoF_d, max_twoF = self.get_max_twoF()

        if method == 'med':
            params = median_std_d
        elif method == 'twoFmax':
            params = max_twoF_d
        else:
            logging.warning(
                "Unknown method '{}': writing no parameters".format(method))
            params = {}

        logging.info('Writing par file with max twoF = {}'.format(max_twoF))
        filename = '{}/{}.par'.format(self.outdir, self.label)
        with open(filename, 'w+') as f:
            f.write('MaxtwoF = {}\n'.format(max_twoF))
            f.write('tref = {}\n'.format(self.tref))
            # The searches store this index as `theta0_idx`
            if hasattr(self, 'theta0_idx'):
                f.write('theta0_index = {}\n'.format(self.theta0_idx))
            for key, val in params.items():
                f.write('{} = {:1.16e}\n'.format(key, val))

    def print_summary(self):
        """ Log a summary of the max-twoF sample and the median +/- std

        All output goes through ``logging.info``; keys are listed in sorted
        order. The theta0 index is included when the search defines
        ``theta0_idx``.
        """
        max_twoFd, max_twoF = self.get_max_twoF()
        median_std_d = self.get_median_stds()
        logging.info('Summary:')
        if hasattr(self, 'theta0_idx'):
            logging.info('theta0 index: {}'.format(self.theta0_idx))
        logging.info('Max twoF: {} with parameters:'.format(max_twoF))
        # sorted() works on both Python 2 key lists and Python 3 key views
        for k in sorted(max_twoFd):
            logging.info('  {:10s} = {:1.9e}'.format(k, max_twoFd[k]))
        logging.info('Median +/- std for production values')
        for k in sorted(median_std_d):
            if 'std' not in k:
                logging.info('  {:10s} = {:1.9e} +/- {:1.9e}'.format(
                    k, median_std_d[k], median_std_d[k+'_std']))
        logging.info('\n')
889

890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
    def CF_twoFmax(self, theta, twoFmax, ntrials):
        """ Characteristic-function integrand for the max of ntrials twoF

        Returns ``exp(i theta x) * p(x)`` at ``x = twoFmax``, where
        ``p(x) = ntrials/2 * F exp(-F) * (1 - (1 + F) exp(-F))**(ntrials-1)``
        with ``F = x / 2``.
        """
        half_twoF = twoFmax / 2.0
        phase = np.exp(1j * theta * twoFmax)
        cdf = 1 - (1 + half_twoF) * np.exp(-half_twoF)
        weight = phase * ntrials / 2.0 * half_twoF * np.exp(-half_twoF)
        return weight * cdf ** (ntrials - 1)

    def pdf_twoFhat(self, twoFhat, nglitch, ntrials, twoFmax=100, dtwoF=0.1):
        """ PDF of the summed per-segment max twoF (twoFhat) in noise

        The characteristic function of each segment's max-twoF distribution
        is computed by trapezium integration of ``CF_twoFmax`` over
        ``[0, twoFmax)`` in steps of ``dtwoF``. The product over segments is
        then inverse-Fourier transformed at each value of ``twoFhat``.

        Parameters
        ----------
        twoFhat: iterable of float
            Points at which to evaluate the PDF
        nglitch: int
            Number of glitches (nglitch + 1 segments)
        ntrials: int or array of length nglitch + 1
            Number of trials per segment; a scalar is used for every segment

        Returns
        -------
        array of the (real part of the) PDF at each twoFhat
        """
        if np.ndim(ntrials) == 0:
            ntrials = np.zeros(nglitch+1) + ntrials
        grid_twoF = np.arange(0, twoFmax, dtwoF)
        grid_theta = np.arange(-1/dtwoF, 1./dtwoF, 1./twoFmax)

        def segment_cf(n_segment):
            return [np.trapz(self.CF_twoFmax(t, grid_twoF, n_segment),
                             grid_twoF)
                    for t in grid_theta]

        cf_per_segment = np.array([segment_cf(n) for n in ntrials])
        cf_total = np.prod(cf_per_segment, axis=0)

        values = []
        for x in twoFhat:
            integrand = np.exp(-1j*grid_theta*x) * cf_total
            values.append(np.trapz(integrand, grid_theta))
        pdf = (1/(2*np.pi)) * np.array(values)
        return pdf.real

    def p_val_twoFhat(self, twoFhat, ntrials, twoFhatmax=500, Npoints=1000):
        """ Calculate the p-value for the given twoFhat in Gaussian noise

        The p-value is the trapezium-rule integral of ``pdf_twoFhat`` (with
        ``self.nglitch`` glitches) from ``twoFhat`` up to ``twoFhatmax`` on
        ``Npoints`` evenly spaced points.

        Parameters
        ----------
        twoFhat: float
            The observed twoFhat value
        ntrials: int, array of len Nglitch+1
            The number of trials for each glitch+1
        twoFhatmax: float
            Upper limit of the integration
        Npoints: int
            Number of integration points
        """
        grid = np.linspace(twoFhat, twoFhatmax, Npoints)
        density = self.pdf_twoFhat(grid, self.nglitch, ntrials)
        return np.trapz(density, grid)

    def get_p_value(self, delta_F0, time_trials=0):
        """ Get the p-value for the maximum twoFhat value

        The data span is split at the glitch times of the max-twoF sample.
        Each segment of duration dT is assigned
        ``time_trials + delta_F0 * dT`` trials. The p-value is printed and
        returned.
        """
        best_params, max_twoF = self.get_max_twoF()
        if self.nglitch == 1:
            glitch_keys = ['tglitch']
        else:
            glitch_keys = ['tglitch_{}'.format(i) for i in range(self.nglitch)]
        boundaries = ([self.minStartTime]
                      + [best_params[key] for key in glitch_keys]
                      + [self.maxStartTime])
        segment_durations = np.diff(boundaries)
        ntrials = [time_trials + delta_F0 * duration
                   for duration in segment_durations]
        p_val = self.p_val_twoFhat(max_twoF, ntrials)
        print('p-value = {}'.format(p_val))
        return p_val

938
    def get_evidence(self):
        """ Thermodynamic-integration evidence estimate in log10 units

        Calls the sampler's ``thermodynamic_integration_log_evidence`` with
        ``fburnin`` set to nburn / (nburn + nprod) from the last two entries
        of ``nsteps``. The natural-log evidence and its error are then
        converted to base 10.

        Returns
        -------
        log10evidence, log10evidence_err: float
        """
        burn_steps = float(self.nsteps[-2])
        fburnin = burn_steps/np.sum(self.nsteps[-2:])
        ln_evidence, ln_evidence_err = (
            self.sampler.thermodynamic_integration_log_evidence(
                fburnin=fburnin))

        ln10 = np.log(10)
        return ln_evidence/ln10, ln_evidence_err/ln10

    def compute_evidence_long(self):
        """ Computes the evidence/marginal likelihood for the model

        The log-evidence is estimated by thermodynamic integration. At each
        inverse temperature in ``self.betas``, the mean log-likelihood is
        taken over walkers and over the steps after the first
        ``nsteps[-2]``. These means are integrated over beta with the
        trapezium rule. Temperatures whose mean log-likelihood is infinite
        are dropped. The error estimate is the difference from the same
        integral using every other temperature, starting from the largest
        beta. A diagnostic plot is saved to
        ``{outdir}/{label}_beta_lnl.png``.

        Returns
        -------
        log10evidence, log10evidence_err: float
            The evidence estimate and its error, in log10 units.
        """
        # Reverse so that betas are in increasing order for the integral
        betas = np.asarray(self.betas)[::-1]
        alllnlikes = self.sampler.lnlikelihood[:, :, self.nsteps[-2]:]
        mean_lnlikes = np.mean(np.mean(alllnlikes, axis=1), axis=1)[::-1]

        fig, (ax1, ax2) = plt.subplots(nrows=2, figsize=(6, 8))

        if any(np.isinf(mean_lnlikes)):
            print("WARNING mean_lnlikes contains inf: recalculating without"
                  " the {} infs".format(len(betas[np.isinf(mean_lnlikes)])))
            idxs = np.isinf(mean_lnlikes)
            mean_lnlikes = mean_lnlikes[~idxs]
            betas = betas[~idxs]

        # Computed for every run, not only when infinities were removed
        z1 = np.trapz(mean_lnlikes, betas)
        z2 = np.trapz(mean_lnlikes[::-1][::2][::-1],
                      betas[::-1][::2][::-1])
        log10evidence = z1/np.log(10)
        log10evidence_err = np.abs(z1 - z2) / np.log(10)

        ax1.semilogx(betas, mean_lnlikes, "-o")
        ax1.set_xlabel(r"$\beta$")
        ax1.set_ylabel(r"$\langle \log(\mathcal{L}) \rangle$")
        print("log10 evidence for {} = {} +/- {}".format(
              self.label, log10evidence, log10evidence_err))
        min_betas = []
        evidence = []
        # Integer division: the loop bound must be an int on Python 3
        for i in range(len(betas) // 2):
            min_betas.append(betas[i])
            lnZ = np.trapz(mean_lnlikes[i:], betas[i:])
            evidence.append(lnZ/np.log(10))

        ax2.semilogx(min_betas, evidence, "-o")
        ax2.set_ylabel(r"$\int_{\beta_{\textrm{Min}}}^{\beta=1}" +
                       r"\langle \log(\mathcal{L})\rangle d\beta$", size=16)
        ax2.set_xlabel(r"$\beta_{\textrm{min}}$")
        plt.tight_layout()
        fig.savefig("{}/{}_beta_lnl.png".format(self.outdir, self.label))
        plt.close(fig)

        return log10evidence, log10evidence_err

989

Gregory Ashton's avatar
Gregory Ashton committed
990
991
class MCMCGlitchSearch(MCMCSearch):
    """ MCMC search using the SemiCoherentGlitchSearch """
    @helper_functions.initializer
    def __init__(self, label, outdir, sftfilepath, theta_prior, tref,
                 minStartTime, maxStartTime, nglitch=1, nsteps=[100, 100],
                 nwalkers=100, ntemps=1, log10temperature_min=-5,
                 theta_initial=None, scatter_val=1e-10, dtglitchmin=1*86400,
                 theta0_idx=0, detector=None, BSGL=False, minCoverFreq=None,
                 maxCoverFreq=None, earth_ephem=None, sun_ephem=None):
        """
        Parameters
        ----------
        label, outdir: str
            A label and directory to read/write data from/to
        sftfilepath: str
            File pattern to match SFTs
        theta_prior: dict
            Dictionary of priors and fixed values for the search parameters.
            For each parameter (key of the dict), if it is to be held fixed
            the value should be the constant float, if it is to be searched,
            the value should be a dictionary of the prior.
        theta_initial: dict, array, (None)
            Either a dictionary of distributions about which to distribute
            the initial walkers, an array (from which the walkers will be
            scattered by scatter_val), or None in which case the prior is used.
        scatter_val, float or ndim array
            Size of scatter to use about the initialisation step, if given as
            an array it must be of length ndim and the order is given by
            theta_keys
        nglitch: int
            The number of glitches to allow
        tref, minStartTime, maxStartTime: int
            GPS seconds of the reference time, start time and end time
        nsteps: list (m,)
            List specifying the number of steps to take, the last two entries
            give the nburn and nprod of the 'production' run, all entries
            before are for iterative initialisation steps (usually just one)
            e.g. [1000, 1000, 500].
        dtglitchmin: int
            The minimum duration (in seconds) of a segment between two glitches
            or a glitch and the start/end of the data (enforced by the prior
            when nglitch > 1)
        nwalkers, ntemps: int,
            The number of walkers and temperatures to use in the parallel
            tempered PTSampler.
        log10temperature_min float < 0
            The  log_10(tmin) value, the set of betas passed to PTSampler are
            generated from np.logspace(0, log10temperature_min, ntemps).
        theta0_idx, int
            Index (zero-based) of which segment the theta refers to - useful
            if providing a tight prior on theta to allow the signal to jump
            to theta (and not just from)
        detector: str
            Two character reference to the data to use, specify None for no
            constraint.
        BSGL: bool
            If True, use the B_S/GL detection statistic instead of twoF
        minCoverFreq, maxCoverFreq: float
            Minimum and maximum instantaneous frequency which will be covered
            over the SFT time span as passed to CreateFstatInput
        earth_ephem, sun_ephem: str
            Paths of the two files containing positions of Earth and Sun,
            respectively at evenly spaced times, as passed to CreateFstatInput
            If None defaults defined in BaseSearchClass will be used

        """

        if os.path.isdir(outdir) is False:
            os.mkdir(outdir)
        self.add_log_file()
        logging.info(('Set-up MCMC glitch search with {} glitches for model {}'
                      ' on data {}').format(self.nglitch, self.label,
                                            self.sftfilepath))
        self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
        self.unpack_input_theta()
        self.ndim = len(self.theta_keys)
        if self.log10temperature_min:
            self.betas = np.logspace(0, self.log10temperature_min, self.ntemps)
        else:
            self.betas = None
        if earth_ephem is None:
            self.earth_ephem = self.earth_ephem_default
        if sun_ephem is None:
            self.sun_ephem = self.sun_ephem_default

        if args.clean and os.path.isfile(self.pickle_path):
            os.rename(self.pickle_path, self.pickle_path+".old")

        self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
        self.log_input()

    def inititate_search_object(self):
        """ Create the SemiCoherentGlitchSearch used to evaluate logl """
        logging.info('Setting up search object')
        # SemiCoherentGlitchSearch is not among this module's top-level
        # imports, so bring it into scope here.
        # NOTE(review): assumes it is provided by `core`, like ComputeFstat.
        from core import SemiCoherentGlitchSearch
        self.search = SemiCoherentGlitchSearch(
            label=self.label, outdir=self.outdir, sftfilepath=self.sftfilepath,
            tref=self.tref, minStartTime=self.minStartTime,
            maxStartTime=self.maxStartTime, minCoverFreq=self.minCoverFreq,
            maxCoverFreq=self.maxCoverFreq, earth_ephem=self.earth_ephem,
            sun_ephem=self.sun_ephem, detector=self.detector, BSGL=self.BSGL,
            nglitch=self.nglitch, theta0_idx=self.theta0_idx)

    def _full_theta(self, theta):
        """ Return a copy of fixed_theta with the sampled values inserted

        The result is ordered as F0, F1, F2, Alpha, Delta, followed by the
        delta_F0s, delta_F1s and tglitches (nglitch of each).
        """
        full_theta = list(self.fixed_theta)
        for j, theta_i in enumerate(self.theta_idxs):
            full_theta[theta_i] = theta[j]
        return full_theta

    def logp(self, theta_vals, theta_prior, theta_keys, search):
        """ Log-prior of the sampled parameters

        For nglitch > 1, returns -inf when the glitch times are not in
        increasing order within [minStartTime, maxStartTime] or when any
        segment is shorter than dtglitchmin.
        """
        if self.nglitch > 1:
            # The tglitches are the last nglitch entries of the full theta,
            # whether they are sampled or held fixed
            tglitches = self._full_theta(theta_vals)[-self.nglitch:]
            ts = [self.minStartTime] + list(tglitches) + [self.maxStartTime]
            if not np.array_equal(ts, np.sort(ts)):
                return -np.inf
            if np.any(np.diff(ts) < self.dtglitchmin):
                return -np.inf

        H = [self.generic_lnprior(**theta_prior[key])(p) for p, key in
             zip(theta_vals, theta_keys)]
        return np.sum(H)

    def logl(self, theta, search):
        """ Log-likelihood: the glitch-search statistic at theta

        Writes the sampled values into ``self.fixed_theta`` and evaluates
        ``search.compute_nglitch_fstat`` on the full parameter list. For
        nglitch > 1, returns -inf if the glitch times are out of order.
        """
        for j, theta_i in enumerate(self.theta_idxs):
            self.fixed_theta[theta_i] = theta[j]

        if self.nglitch > 1:
            ts = ([self.minStartTime] + list(self.fixed_theta[-self.nglitch:])
                  + [self.maxStartTime])
            if not np.array_equal(ts, np.sort(ts)):
                return -np.inf

        FS = search.compute_nglitch_fstat(*self.fixed_theta)
        return FS

    def unpack_input_theta(self):
        """ Split theta_prior into sampled and fixed parameters

        Sets ``theta_keys``, ``theta_idxs``, ``theta_symbols`` and
        ``fixed_theta``. Entries of theta_prior whose value is a dict are
        searched over (with a 0 placeholder in fixed_theta); numeric entries
        are held fixed. The plain glitch keys 'delta_F0', 'delta_F1' and
        'tglitch' stand for all nglitch glitches at once. If 'tglitch_0' is in
        theta_prior, the indexed names 'delta_F0_i', 'delta_F1_i' and
        'tglitch_i' are expected instead.

        Raises
        ------
        ValueError
            For an unsupported value type, an unknown or repeated key, or
            missing keys.
        """
        glitch_keys = ['delta_F0', 'delta_F1', 'tglitch']
        n = self.nglitch
        full_glitch_keys = [gk for gk in glitch_keys for _ in range(n)]

        if 'tglitch_0' in self.theta_prior:
            # Blocks of n entries each: delta_F0s, delta_F1s, tglitches
            full_glitch_keys[-n:] = [
                'tglitch_{}'.format(i) for i in range(n)]
            full_glitch_keys[-2*n:-n] = [
                'delta_F1_{}'.format(i) for i in range(n)]
            full_glitch_keys[-3*n:-2*n] = [
                'delta_F0_{}'.format(i) for i in range(n)]
        full_theta_keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta']+full_glitch_keys
        full_theta_keys_copy = list(full_theta_keys)

        glitch_symbols = [r'$\delta f$', r'$\delta \dot{f}$', r'$t_{glitch}$']
        full_glitch_symbols = [gs for gs in glitch_symbols for _ in range(n)]
        full_theta_symbols = ([r'$f$', r'$\dot{f}$', r'$\ddot{f}$',
                               r'$\alpha$', r'$\delta$'] + full_glitch_symbols)
        self.theta_keys = []
        fixed_theta_dict = {}
        for key, val in self.theta_prior.items():
            if isinstance(val, dict):
                fixed_theta_dict[key] = 0
                if key in glitch_keys:
                    self.theta_keys.extend([key] * n)
                else:
                    self.theta_keys.append(key)
            elif (isinstance(val, (int, float, np.integer, np.floating))
                  and not isinstance(val, bool)):
                fixed_theta_dict[key] = val
            else:
                raise ValueError(
                    'Type {} of {} in theta not recognised'.format(
                        type(val), key))
            ncopies = n if key in glitch_keys else 1
            for _ in range(ncopies):
                if key not in full_theta_keys_copy:
                    raise ValueError(
                        'Key {} in theta is unknown or repeated'.format(key))
                full_theta_keys_copy.remove(key)

        if len(full_theta_keys_copy) > 0:
            raise ValueError(('Input dictionary `theta` is missing the '
                              'following keys: {}').format(
                                  full_theta_keys_copy))

        self.fixed_theta = [fixed_theta_dict[key] for key in full_theta_keys]
        self.theta_idxs = [full_theta_keys.index(k) for k in self.theta_keys]
        self.theta_symbols = [full_theta_symbols[i] for i in self.theta_idxs]

        idxs = np.argsort(self.theta_idxs)
        self.theta_idxs = [self.theta_idxs[i] for i in idxs]
        self.theta_symbols = [self.theta_symbols[i] for i in idxs]
        self.theta_keys = [self.theta_keys[i] for i in idxs]

        # Correct for number of glitches in the idxs: a repeated glitch key
        # initially maps to the same index, so bump duplicates until unique
        self.theta_idxs = np.array(self.theta_idxs)
        while np.sum(self.theta_idxs[:-1] == self.theta_idxs[1:]) > 0:
            for i, idx in enumerate(self.theta_idxs):
                if idx in self.theta_idxs[:i]:
                    self.theta_idxs[i] += 1

    def get_save_data_dictionary(self):
        """ Run settings stored with (and compared against) saved data """
        d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers,
                 ntemps=self.ntemps, theta_keys=self.theta_keys,
                 theta_prior=self.theta_prior, scatter_val=self.scatter_val,
                 log10temperature_min=self.log10temperature_min,
                 theta0_idx=self.theta0_idx, BSGL=self.BSGL)
        return d

    def apply_corrections_to_p0(self, p0):
        """ Sort the initial glitch times of every walker (nglitch > 1)

        ``p0`` is converted to an array indexed [temperature, walker, dim];
        the last nglitch dimensions (the tglitches) are sorted in place.
        """
        p0 = np.array(p0)
        if self.nglitch > 1:
            p0[:, :, -self.nglitch:] = np.sort(p0[:, :, -self.nglitch:],
                                               axis=2)
        return p0

    def plot_cumulative_max(self):
        """ Plot the cumulative twoF of each segment between glitches

        Uses the max-twoF parameters (with fixed values from theta_prior
        filled in). Segments shorter than 5 days are skipped. The figure is
        saved to ``{outdir}/{label}_twoFcumulative.png``.

        Raises
        ------
        ValueError
            If nglitch < 1.
        """
        if hasattr(self, 'search') is False:
            self.inititate_search_object()

        fig, ax = plt.subplots()
        d, maxtwoF = self.get_max_twoF()
        for key, val in self.theta_prior.items():
            if key not in d:
                d[key] = val

        if self.nglitch > 1:
            delta_F0s = [d['delta_F0_{}'.format(i)] for i in
                         range(self.nglitch)]
            tglitches = [d['tglitch_{}'.format(i)] for i in
                         range(self.nglitch)]
        elif self.nglitch == 1:
            delta_F0s = [d['delta_F0']]
            tglitches = [d['tglitch']]
        else:
            plt.close(fig)
            raise ValueError('plot_cumulative_max requires nglitch >= 1')
        # One entry per segment: theta0's segment has no offset and the
        # offsets before it enter with the opposite sign
        delta_F0s.insert(self.theta0_idx, 0)
        delta_F0s = np.array(delta_F0s)
        delta_F0s[:self.theta0_idx] *= -1

        tboundaries = [self.minStartTime] + tglitches + [self.maxStartTime]

        for j in range(self.nglitch+1):
            ts = tboundaries[j]
            te = tboundaries[j+1]
            if (te - ts)/86400 < 5:
                logging.info('Period too short to perform cumulative search')
                continue
            if j < self.theta0_idx:
                summed_deltaF0 = np.sum(delta_F0s[j:self.theta0_idx])
                F0_j = d['F0'] - summed_deltaF0
            else:
                summed_deltaF0 = np.sum(delta_F0s[self.theta0_idx:j+1])
                F0_j = d['F0'] + summed_deltaF0
            taus, twoFs = self.search.calculate_twoF_cumulative(
                F0_j, F1=d['F1'], F2=d['F2'], Alpha=d['Alpha'],
                Delta=d['Delta'], tstart=ts, tend=te)
            ax.plot(ts+taus, twoFs)

        ax.set_xlabel('GPS time')
        fig.savefig('{}/{}_twoFcumulative.png'.format(self.outdir, self.label))
        plt.close(fig)

Gregory Ashton's avatar
Gregory Ashton committed
1240

1241
1242
class MCMCSemiCoherentSearch(MCMCSearch):
    """ MCMC search for a signal using the semi-coherent ComputeFstat """
Gregory Ashton's avatar
Gregory Ashton committed
1243
    @helper_functions.initializer