""" Searches using MCMC-based methods """
from __future__ import division, absolute_import, print_function

import sys
import os
import copy
import logging
from collections import OrderedDict
import subprocess

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import emcee
import corner
import dill as pickle

import pyfstat.core as core
from pyfstat.core import tqdm, args, read_par
import pyfstat.optimal_setup_functions as optimal_setup_functions
import pyfstat.helper_functions as helper_functions


class MCMCSearch(core.BaseSearchClass):
    """MCMC search using ComputeFstat
    Parameters
    ----------
    label, outdir: str
        A label and directory to read/write data from/to
    theta_prior: dict
        Dictionary of priors and fixed values for the search parameters.
        For each parameter (a key of the dict), if it is to be held fixed
        the value should be a constant float; if it is to be searched over,
        the value should be a dictionary defining the prior.
    tref, minStartTime, maxStartTime: int
        GPS seconds of the reference time, start time and end time
    sftfilepattern: str
        Pattern to match SFTs using wildcards (*?) and ranges [0-9];
        multiple patterns can be given separated by colons.
    detectors: str
        Two character reference to the detectors to use; specify None for no
        constraint and comma-separate multiple references.
    nsteps: list (2,)
        Number of burn-in and production steps to take, [nburn, nprod]. See
        `pyfstat.MCMCSearch.setup_initialisation()` for details on adding
        initialisation steps.
    nwalkers, ntemps: int
        The number of walkers and temperatures to use in the parallel
        tempered PTSampler.
    log10beta_min: float < 0
        The log_10(beta) value; if given, the set of betas passed to PTSampler
        are generated from `np.logspace(0, log10beta_min, ntemps)` (given
        in descending order to emcee).
    theta_initial: dict, array, (None)
        A dictionary of distributions from which to draw the initial
        positions of the walkers
    rhohatmax: float
        Upper bound for the SNR scale parameter (required to normalise the
        Bayes factor) - this needs to be carefully set when using the
        evidence.
    binary: bool
        If true, search over binary parameters
    BSGL: bool
        If true, use the BSGL statistic
    SSBprec: int
        SSBprec (SSB precision) to use when calling ComputeFstat
    minCoverFreq, maxCoverFreq: float
        Minimum and maximum instantaneous frequency which will be covered
        over the SFT time span as passed to CreateFstatInput
    injectSources: dict
        If given, inject these properties into the SFT files before running
        the search
    assumeSqrtSX: float
        Don't estimate noise-floors, but assume (stationary) per-IFO sqrt{SX}

    Attributes
    ----------
    symbol_dictionary: dict
        Key, val pairs of the parameters (i.e. `F0`, `F1`), to Latex math
        symbols for plots
    unit_dictionary: dict
        Key, val pairs of the parameters (i.e. `F0`, `F1`), and the
        units (i.e. `Hz`)
    transform_dictionary: dict
        Key, val pairs of the parameters (i.e. `F0`, `F1`), where the value is
        itself a dictionary which may contain the items `multiplier`,
        `subtractor`, or `unit` by which to transform the parameter and update
        the units.

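    Examples
    --------
    A minimal usage sketch; the label, directory, prior settings and GPS
    times below are illustrative assumptions, not defaults:

    >>> theta_prior = {'F0': {'type': 'unif', 'lower': 29.9, 'upper': 30.1},
    ...                'F1': -1e-10,
    ...                'F2': 0,
    ...                'Alpha': 1.0,
    ...                'Delta': 0.5}
    >>> mcmc = MCMCSearch(label='example', outdir='data',
    ...                   theta_prior=theta_prior, tref=1000000000,
    ...                   minStartTime=1000000000, maxStartTime=1001000000,
    ...                   sftfilepattern='data/*.sft')
    >>> mcmc.run()
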
    """

    symbol_dictionary = dict(
        F0='$f$', F1=r'$\dot{f}$', F2=r'$\ddot{f}$', Alpha=r'$\alpha$',
        Delta=r'$\delta$', asini='asini', period='P', ecc='ecc', tp='tp',
        argp='argp')
    unit_dictionary = dict(
        F0='Hz', F1='Hz/s', F2='Hz/s$^2$', Alpha=r'rad', Delta='rad',
        asini='', period='s', ecc='', tp='', argp='')
    transform_dictionary = {}
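    # transform_dictionary is empty by default; subclasses may override it to
    # rescale parameters for plotting. A hypothetical entry could look like:
    # transform_dictionary = dict(
    #     F0=dict(subtractor=30.0, multiplier=1e6, unit='micro-Hz'))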

    @helper_functions.initializer
    def __init__(self, label, outdir, theta_prior, tref, minStartTime,
                 maxStartTime, sftfilepattern=None, detectors=None,
                 nsteps=[100, 100], nwalkers=100, ntemps=1,
                 log10beta_min=-5, theta_initial=None,
                 rhohatmax=1000, binary=False, BSGL=False,
                 SSBprec=None, minCoverFreq=None, maxCoverFreq=None,
                 injectSources=None, assumeSqrtSX=None):

        if os.path.isdir(outdir) is False:
            os.mkdir(outdir)
        self._add_log_file()
        logging.info('Set-up MCMC search for model {}'.format(self.label))
        if sftfilepattern:
            logging.info('Using data {}'.format(self.sftfilepattern))
        else:
            logging.info('No sftfilepattern given')
        if injectSources:
            logging.info('Inject sources: {}'.format(injectSources))
        self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
        self._unpack_input_theta()
        self.ndim = len(self.theta_keys)
        if self.log10beta_min:
            self.betas = np.logspace(0, self.log10beta_min, self.ntemps)
        else:
            self.betas = None

        if args.clean and os.path.isfile(self.pickle_path):
            os.rename(self.pickle_path, self.pickle_path+".old")

        self._set_likelihoodcoef()
        self._log_input()

    def _set_likelihoodcoef(self):
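        # log(70/rhohatmax^4) normalises the likelihood under the SNR-scale
        # prior bounded by rhohatmax; it is added to every log-likelihood
        # evaluation in `logl` (see the `rhohatmax` docstring above).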
        self.likelihoodcoef = np.log(70./self.rhohatmax**4)

    def _log_input(self):
        logging.info('theta_prior = {}'.format(self.theta_prior))
        logging.info('nwalkers={}'.format(self.nwalkers))
        logging.info('nsteps = {}'.format(self.nsteps))
        logging.info('ntemps = {}'.format(self.ntemps))
        logging.info('log10beta_min = {}'.format(self.log10beta_min))

    def _initiate_search_object(self):
        logging.info('Setting up search object')
        self.search = core.ComputeFstat(
            tref=self.tref, sftfilepattern=self.sftfilepattern,
            minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq,
            detectors=self.detectors, BSGL=self.BSGL, transient=False,
            minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
            binary=self.binary, injectSources=self.injectSources,
            assumeSqrtSX=self.assumeSqrtSX, SSBprec=self.SSBprec)

    def logp(self, theta_vals, theta_prior, theta_keys, search):
        H = [self._generic_lnprior(**theta_prior[key])(p) for p, key in
             zip(theta_vals, theta_keys)]
        return np.sum(H)

    def logl(self, theta, search):
        for j, theta_i in enumerate(self.theta_idxs):
            self.fixed_theta[theta_i] = theta[j]
        FS = search.compute_fullycoherent_det_stat_single_point(
            *self.fixed_theta)
        return FS + self.likelihoodcoef

    def _unpack_input_theta(self):
        full_theta_keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta']
        if self.binary:
            full_theta_keys += [
                'asini', 'period', 'ecc', 'tp', 'argp']
        full_theta_keys_copy = copy.copy(full_theta_keys)

        full_theta_symbols = ['$f$', r'$\dot{f}$', r'$\ddot{f}$', r'$\alpha$',
                              r'$\delta$']
        if self.binary:
            full_theta_symbols += [
                'asini', 'period', 'ecc', 'tp', 'argp']

        self.theta_keys = []
        fixed_theta_dict = {}
        for key, val in self.theta_prior.items():
            if type(val) is dict:
                fixed_theta_dict[key] = 0
                self.theta_keys.append(key)
            elif type(val) in [float, int, np.float64]:
                fixed_theta_dict[key] = val
            else:
                raise ValueError(
                    'Type {} of {} in theta not recognised'.format(
                        type(val), key))
            full_theta_keys_copy.pop(full_theta_keys_copy.index(key))

        if len(full_theta_keys_copy) > 0:
            raise ValueError(('Input dictionary `theta` is missing the '
                              'following keys: {}').format(
                                  full_theta_keys_copy))

        self.fixed_theta = [fixed_theta_dict[key] for key in full_theta_keys]
        self.theta_idxs = [full_theta_keys.index(k) for k in self.theta_keys]
        self.theta_symbols = [full_theta_symbols[i] for i in self.theta_idxs]

        idxs = np.argsort(self.theta_idxs)
        self.theta_idxs = [self.theta_idxs[i] for i in idxs]
        self.theta_symbols = [self.theta_symbols[i] for i in idxs]
        self.theta_keys = [self.theta_keys[i] for i in idxs]

    def _check_initial_points(self, p0):
        for nt in range(self.ntemps):
            logging.info('Checking temperature {} chains'.format(nt))
            initial_priors = np.array([
                self.logp(p, self.theta_prior, self.theta_keys, self.search)
                for p in p0[nt]])
            number_of_initial_out_of_bounds = sum(initial_priors == -np.inf)

            if number_of_initial_out_of_bounds > 0:
                logging.warning(
                    'Of {} initial values, {} are -np.inf due to the prior'
                    .format(len(initial_priors),
                            number_of_initial_out_of_bounds))
                p0 = self._generate_new_p0_to_fix_initial_points(
                    p0, nt, initial_priors)

    def _generate_new_p0_to_fix_initial_points(self, p0, nt, initial_priors):
        logging.info('Attempting to correct initial values')
        idxs = np.arange(self.nwalkers)[initial_priors == -np.inf]
        count = 0
        while sum(initial_priors == -np.inf) > 0 and count < 100:
            for j in idxs:
                p0[nt][j] = (p0[nt][np.random.randint(0, self.nwalkers)]*(
                             1+np.random.normal(0, 1e-10, self.ndim)))
            initial_priors = np.array([
                self.logp(p, self.theta_prior, self.theta_keys,
                          self.search)
                for p in p0[nt]])
            count += 1

        if sum(initial_priors == -np.inf) > 0:
            logging.info('Failed to fix initial priors')
        else:
            logging.info('Succeeded in fixing initial priors')

        return p0

    def setup_burnin_convergence_testing(
            self, n=10, test_type='autocorr', windowed=False, **kwargs):
        """ Set up convergence testing during the MCMC simulation

        Parameters
        ----------
        n: int
            Number of steps after which to test convergence
        test_type: str ['autocorr', 'GR']
            If 'autocorr' use the exponential autocorrelation time (kwargs
            passed to `_test_autocorr_convergence`). If 'GR' use the
            Gelman-Rubin statistic (kwargs passed to `_test_GR_convergence`)
        windowed: bool
            If True, only calculate the convergence test in a window of length
            `n`
        **kwargs:
            Passed to either `_test_autocorr_convergence()` or
            `_test_GR_convergence()` depending on `test_type`.

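        Examples
        --------
        A usage sketch (`mcmc` an already-constructed MCMCSearch instance):

        >>> mcmc.setup_burnin_convergence_testing(n=100, test_type='GR')
        >>> mcmc.run()
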
        """
        logging.info('Setting up convergence testing')
        self.convergence_n = n
        self.convergence_windowed = windowed
        self.convergence_test_type = test_type
        self.convergence_kwargs = kwargs
        self.convergence_diagnostic = []
        self.convergence_diagnosticx = []
        if test_type in ['autocorr']:
            self._get_convergence_test = self._test_autocorr_convergence
        elif test_type in ['GR']:
            self._get_convergence_test = self._test_GR_convergence
        else:
            raise ValueError('test_type {} not understood'.format(test_type))

    def setup_initialisation(self, nburn0, scatter_val=1e-10):
        """ Add an initialisation step to the MCMC run

        If called prior to `run()`, adds an initial step in which the MCMC
        simulation is run for `nburn0` steps. After this, the MCMC simulation
        continues in the usual manner (i.e. for nburn and nprod steps), but
        the walkers are reset, scattered around the maximum likelihood
        position found in the initialisation step.

        Parameters
        ----------
        nburn0: int
            Number of initialisation steps to take
        scatter_val: float
            Relative scale with which to scatter the walkers around the
            maximum likelihood position after the initialisation step

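        Examples
        --------
        With `nsteps=[100, 100]` at construction, the following call results
        in `nsteps=[50, 100, 100]`, i.e. 50 initialisation steps:

        >>> mcmc.setup_initialisation(nburn0=50, scatter_val=1e-8)
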
        """

        logging.info('Setting up initialisation with nburn0={}, scatter_val={}'
                     .format(nburn0, scatter_val))
        self.nsteps = [nburn0] + self.nsteps
        self.scatter_val = scatter_val

    def _test_autocorr_convergence(self, i, sampler, test=True, n_cut=5):
        try:
            acors = np.zeros((self.ntemps, self.ndim))
            for temp in range(self.ntemps):
                if self.convergence_windowed:
                    j = i-self.convergence_n
                else:
                    j = 0
                x = np.mean(sampler.chain[temp, :, j:i, :], axis=0)
                acors[temp, :] = emcee.autocorr.exponential_time(x)
            c = np.max(acors, axis=0)
        except emcee.autocorr.AutocorrError:
            logging.info('Failed to calculate exponential autocorrelation')
            c = np.zeros(self.ndim) + np.nan
        except AttributeError:
            logging.info('Unable to calculate exponential autocorrelation')
            c = np.zeros(self.ndim) + np.nan

        self.convergence_diagnosticx.append(i - self.convergence_n/2.)
        self.convergence_diagnostic.append(list(c))

        if test:
            return i > n_cut * np.max(c)

    def _test_GR_convergence(self, i, sampler, test=True, R=1.1):
        # Gelman-Rubin convergence test on the zero-temperature chains:
        # compare the within-chain variance W to the between-chain variance B
        # and form the potential scale reduction factor (PSRF) sqrt(Vhat/W).
        if self.convergence_windowed:
            s = sampler.chain[0, :, i-self.convergence_n+1:i+1, :]
        else:
            s = sampler.chain[0, :, :i+1, :]
        N = float(self.convergence_n)
        M = float(self.nwalkers)
        # W: mean over walkers of the per-walker (within-chain) variance
        W = np.mean(np.var(s, axis=1), axis=0)
        per_walker_mean = np.mean(s, axis=1)
        mean = np.mean(per_walker_mean, axis=0)
        # B: between-chain variance of the per-walker means
        B = N / (M-1.) * np.sum((per_walker_mean-mean)**2, axis=0)
        Vhat = (N-1)/N * W + (M+1)/(M*N) * B
        c = np.sqrt(Vhat/W)
        self.convergence_diagnostic.append(c)
        self.convergence_diagnosticx.append(i - self.convergence_n/2.)

        if test and np.max(c) < R:
            return True
        else:
            return False

    def _test_convergence(self, i, sampler, **kwargs):
        if np.mod(i+1, self.convergence_n) == 0:
            return self._get_convergence_test(i, sampler, **kwargs)
        else:
            return False

    def _run_sampler_with_conv_test(self, sampler, p0, nprod=0, nburn=0):
        logging.info('Running {} burn-in steps with convergence testing'
                     .format(nburn))
        iterator = tqdm(sampler.sample(p0, iterations=nburn), total=nburn)
        for i, output in enumerate(iterator):
            if self._test_convergence(i, sampler, test=True,
                                      **self.convergence_kwargs):
                logging.info(
                    'Converged at {} before max number {} of steps reached'
                    .format(i, nburn))
                self.convergence_idx = i
                break
        iterator.close()
        logging.info('Running {} production steps'.format(nprod))
        j = nburn
        iterator = tqdm(sampler.sample(output[0], iterations=nprod),
                        total=nprod)
        for result in iterator:
            self._test_convergence(j, sampler, test=False,
                                   **self.convergence_kwargs)
            j += 1
        return sampler

    def _run_sampler(self, sampler, p0, nprod=0, nburn=0):
        if hasattr(self, 'convergence_n'):
            sampler = self._run_sampler_with_conv_test(sampler, p0, nprod,
                                                       nburn)
        else:
            for result in tqdm(sampler.sample(p0, iterations=nburn+nprod),
                               total=nburn+nprod):
                pass

        self.mean_acceptance_fraction = np.mean(
            sampler.acceptance_fraction, axis=1)
        logging.info("Mean acceptance fraction: {}"
                     .format(self.mean_acceptance_fraction))
        if self.ntemps > 1:
            self.tswap_acceptance_fraction = sampler.tswap_acceptance_fraction
            logging.info("Tswap acceptance fraction: {}"
                         .format(sampler.tswap_acceptance_fraction))
        try:
            self.autocorr_time = sampler.get_autocorr_time(c=4)
            logging.info("Autocorrelation length: {}".format(
                self.autocorr_time))
        except emcee.autocorr.AutocorrError as e:
            self.autocorr_time = np.nan
            logging.warning(
                'Autocorrelation calculation failed with message {}'.format(e))

        return sampler

    def _estimate_run_time(self):
        """ Print the estimated run time

        Uses timing coefficients based on a Lenovo T460p Intel(R)
        Core(TM) i5-6300HQ CPU @ 2.30GHz.

        """
        # Todo: add option to time on a machine, and move coefficients to
        # ~/.pyfstat.conf
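        # Run-time model: t = (tau0S + tau0LD * Nsfts) * Nevals, with separate
        # coefficients depending on whether the sky position is searched over.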
        if (type(self.theta_prior['Alpha']) == dict or
                type(self.theta_prior['Delta']) == dict):
            tau0S = 7.3e-5
            tau0LD = 4.2e-7
        else:
            tau0S = 5.0e-5
            tau0LD = 6.2e-8
        Nsfts = (self.maxStartTime - self.minStartTime) / 1800.
        numb_evals = np.sum(self.nsteps)*self.nwalkers*self.ntemps
        a = tau0S * numb_evals
        b = tau0LD * Nsfts * numb_evals
        logging.info('Estimated run-time = {} s = {:1.0f}:{:1.0f} m'.format(
            a+b, *divmod(a+b, 60)))

    def run(self, proposal_scale_factor=2, create_plots=True, c=5, **kwargs):
        """ Run the MCMC simulatation

        Parameters
        ----------
        proposal_scale_factor: float
            The proposal scale factor used by the sampler, see Goodman & Weare
            (2010). If the acceptance fraction is too low, you can raise it by
            decreasing the a parameter; and if it is too high, you can reduce
            it by increasing the a parameter [Foreman-Mackey (2013)].
        create_plots: bool
            If true, save trace plots of the walkers
        c: int
            The minimum number of autocorrelation times needed to trust the
            result when estimating the autocorrelation time (see
            emcee.autocorr.integrated_time for further details). Default is 5
        **kwargs:
            Passed to _plot_walkers to control the figures

        Returns
        -------
        sampler: emcee.ptsampler.PTSampler
            The emcee ptsampler object

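        Examples
        --------
        A minimal sketch (assuming `mcmc` was constructed as in the class
        docstring example):

        >>> sampler = mcmc.run(create_plots=False)
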
        """

        self.old_data_is_okay_to_use = self._check_old_data_is_okay_to_use()
        if self.old_data_is_okay_to_use is True:
            logging.warning('Using saved data from {}'.format(
                self.pickle_path))
            d = self.get_saved_data_dictionary()
            self.samples = d['samples']
            self.lnprobs = d['lnprobs']
            self.lnlikes = d['lnlikes']
            self.all_lnlikelihood = d['all_lnlikelihood']
            return

        self._initiate_search_object()
        self._estimate_run_time()

        sampler = emcee.PTSampler(
            self.ntemps, self.nwalkers, self.ndim, self.logl, self.logp,
            logpargs=(self.theta_prior, self.theta_keys, self.search),
            loglargs=(self.search,), betas=self.betas, a=proposal_scale_factor)

        p0 = self._generate_initial_p0()
        p0 = self._apply_corrections_to_p0(p0)
        self._check_initial_points(p0)

        # Run initialisation steps if required
        ninit_steps = len(self.nsteps) - 2
        for j, n in enumerate(self.nsteps[:-2]):
            logging.info('Running {}/{} initialisation with {} steps'.format(
                j, ninit_steps, n))
            sampler = self._run_sampler(sampler, p0, nburn=n)
            if create_plots:
                fig, axes = self._plot_walkers(sampler,
                                               symbols=self.theta_symbols,
                                               **kwargs)
                fig.tight_layout()
                fig.savefig('{}/{}_init_{}_walkers.png'.format(
                    self.outdir, self.label, j))

            p0 = self._get_new_p0(sampler)
            p0 = self._apply_corrections_to_p0(p0)
            self._check_initial_points(p0)
            sampler.reset()

        if len(self.nsteps) > 1:
            nburn = self.nsteps[-2]
        else:
            nburn = 0
        nprod = self.nsteps[-1]
        logging.info('Running final burn and prod with {} steps'.format(
            nburn+nprod))
        sampler = self._run_sampler(sampler, p0, nburn=nburn, nprod=nprod)
        if create_plots:
            fig, axes = self._plot_walkers(sampler, symbols=self.theta_symbols,
                                           nprod=nprod, **kwargs)
            fig.tight_layout()
            fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label))

        samples = sampler.chain[0, :, nburn:, :].reshape((-1, self.ndim))
        lnprobs = sampler.lnprobability[0, :, nburn:].reshape((-1))
        lnlikes = sampler.lnlikelihood[0, :, nburn:].reshape((-1))
        all_lnlikelihood = sampler.lnlikelihood[:, :, nburn:]
        self.samples = samples
        self.lnprobs = lnprobs
        self.lnlikes = lnlikes
        self.all_lnlikelihood = all_lnlikelihood
        self._save_data(sampler, samples, lnprobs, lnlikes, all_lnlikelihood)
        return sampler

    def _get_rescale_multiplier_for_key(self, key):
        """ Get the rescale multiplier from the transform_dictionary

        Can either be a float, a string (in which case it is interpreted as
        an attribute of the MCMCSearch class, e.g. minStartTime), or
        non-existent, in which case 1 is returned
        """
        if key not in self.transform_dictionary:
            return 1

        if 'multiplier' in self.transform_dictionary[key]:
            val = self.transform_dictionary[key]['multiplier']
            if type(val) == str:
                if hasattr(self, val):
                    multiplier = getattr(
                        self, self.transform_dictionary[key]['multiplier'])
                else:
                    raise ValueError(
                        "multiplier {} not a class attribute".format(val))
            else:
                multiplier = val
        else:
            multiplier = 1
        return multiplier

    def _get_rescale_subtractor_for_key(self, key):
        """ Get the rescale subtractor from the transform_dictionary

        Can either be a float, a string (in which case it is interpreted as
        an attribute of the MCMCSearch class, e.g. minStartTime), or
        non-existent, in which case 0 is returned
        """
        if key not in self.transform_dictionary:
            return 0

        if 'subtractor' in self.transform_dictionary[key]:
            val = self.transform_dictionary[key]['subtractor']
            if type(val) == str:
                if hasattr(self, val):
                    subtractor = getattr(
                        self, self.transform_dictionary[key]['subtractor'])
                else:
                    raise ValueError(
                        "subtractor {} not a class attribute".format(val))
            else:
                subtractor = val
        else:
            subtractor = 0
        return subtractor

    def _scale_samples(self, samples, theta_keys):
        """ Scale the samples using the transform_dictionary """
        for key in theta_keys:
            if key in self.transform_dictionary:
                idx = theta_keys.index(key)
                s = samples[:, idx]
                subtractor = self._get_rescale_subtractor_for_key(key)
                s = s - subtractor
                multiplier = self._get_rescale_multiplier_for_key(key)
                s *= multiplier
                samples[:, idx] = s

        return samples

    def _get_labels(self):
        """ Combine the units, symbols and rescaling to give labels """

        labels = []
        for key in self.theta_keys:
            label = None
            s = self.symbol_dictionary[key]
            s = s.replace('_{glitch}', r'_\textrm{glitch}')
            u = self.unit_dictionary[key]
            if key in self.transform_dictionary:
                if 'symbol' in self.transform_dictionary[key]:
                    s = self.transform_dictionary[key]['symbol']
                if 'label' in self.transform_dictionary[key]:
                    label = self.transform_dictionary[key]['label']
                if 'unit' in self.transform_dictionary[key]:
                    u = self.transform_dictionary[key]['unit']
            if label is None:
                label = '{} \n [{}]'.format(s, u)
            labels.append(label)
        return labels

    def plot_corner(self, figsize=(7, 7), add_prior=False, nstds=None,
                    label_offset=0.4, dpi=300, rc_context={},
                    tglitch_ratio=False, fig_and_axes=None, save_fig=True,
                    **kwargs):
        """ Generate a corner plot of the posterior

        Using the `corner` package (https://pypi.python.org/pypi/corner/),
        generate estimates of the posterior from the production samples.

        Parameters
        ----------
        figsize: tuple (7, 7)
            Figure size in inches (passed to plt.subplots)
        add_prior: bool, str
            If true, plot the prior as a red line. If 'full' then for uniform
            priors plot the full extent of the prior.
        nstds: float
            The number of standard deviations to plot centered on the mean
        label_offset: float
            Offset the labels from the plot: useful to prevent overlapping the
            tick labels with the axis labels
        dpi: int
            Passed to plt.savefig
        rc_context: dict
            Dictionary of rc values to set while generating the figure (see
            matplotlib rc for more details)
        tglitch_ratio: bool
            If true, and tglitch is a parameter, plot posteriors as the
            fractional time at which the glitch occurs instead of the actual
            time
        fig_and_axes: tuple
            fig and axes to plot on, the axes must be of the right shape,
            namely (ndim, ndim)
        save_fig: bool
            If true, save the figure, else return the fig, axes
        **kwargs:
            Passed to corner.corner

        Returns
        -------
        fig, axes:
            The matplotlib figure and axes, only returned if save_fig = False

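        Examples
        --------
        A usage sketch (`mcmc` after a completed `run()`):

        >>> mcmc.plot_corner(add_prior=True, nstds=4)
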
        """

        if 'truths' in kwargs and len(kwargs['truths']) != self.ndim:
            logging.warning('len(Truths) != ndim, Truths will be ignored')
            kwargs['truths'] = None

        if self.ndim < 2:
            with plt.rc_context(rc_context):
                if fig_and_axes is None:
                    fig, ax = plt.subplots(figsize=figsize)
                else:
                    fig, ax = fig_and_axes
                ax.hist(self.samples, bins=50, histtype='stepfilled')
                ax.set_xlabel(self.theta_symbols[0])

            fig.savefig('{}/{}_corner.png'.format(
                self.outdir, self.label), dpi=dpi)
            return

        with plt.rc_context(rc_context):
            if fig_and_axes is None:
                fig, axes = plt.subplots(self.ndim, self.ndim,
                                         figsize=figsize)
            else:
                fig, axes = fig_and_axes

            samples_plt = copy.copy(self.samples)
            labels = self._get_labels()

            samples_plt = self._scale_samples(samples_plt, self.theta_keys)

            if tglitch_ratio:
                for j, k in enumerate(self.theta_keys):
                    if k == 'tglitch':
                        s = samples_plt[:, j]
                        samples_plt[:, j] = (
                            s - self.minStartTime)/(
                                self.maxStartTime - self.minStartTime)
                        labels[j] = r'$R_{\textrm{glitch}}$'

            if type(nstds) is int and 'range' not in kwargs:
                _range = []
                for j, s in enumerate(samples_plt.T):
                    median = np.median(s)
                    std = np.std(s)
                    _range.append((median - nstds*std, median + nstds*std))
            elif 'range' in kwargs:
                _range = kwargs.pop('range')
            else:
                _range = None

            hist_kwargs = kwargs.pop('hist_kwargs', dict())
            if 'normed' not in hist_kwargs:
                hist_kwargs['normed'] = True

            fig_triangle = corner.corner(samples_plt,
                                         labels=labels,
                                         fig=fig,
                                         bins=50,
                                         max_n_ticks=4,
                                         plot_contours=True,
                                         plot_datapoints=True,
                                         label_kwargs={'fontsize': 8},
                                         data_kwargs={'alpha': 0.1,
                                                      'ms': 0.5},
                                         range=_range,
                                         hist_kwargs=hist_kwargs,
                                         **kwargs)

            axes_list = fig_triangle.get_axes()
            axes = np.array(axes_list).reshape(self.ndim, self.ndim)
            plt.draw()
            for ax in axes[:, 0]:
                ax.yaxis.set_label_coords(-label_offset, 0.5)
            for ax in axes[-1, :]:
                ax.xaxis.set_label_coords(0.5, -label_offset)
            for ax in axes_list:
                ax.set_rasterized(True)
                ax.set_rasterization_zorder(-10)
            plt.tight_layout(h_pad=0.0, w_pad=0.0)
            fig.subplots_adjust(hspace=0.05, wspace=0.05)

            if add_prior:
                self._add_prior_to_corner(axes, self.samples, add_prior)

            if save_fig:
                fig_triangle.savefig('{}/{}_corner.png'.format(
                    self.outdir, self.label), dpi=dpi)
            else:
                return fig, axes

    def _add_prior_to_corner(self, axes, samples, add_prior):
        for i, key in enumerate(self.theta_keys):
            ax = axes[i][i]
            s = samples[:, i]
            lnprior = self._generic_lnprior(**self.theta_prior[key])
            if add_prior == 'full' and self.theta_prior[key]['type'] == 'unif':
                lower = self.theta_prior[key]['lower']
                upper = self.theta_prior[key]['upper']
                r = upper-lower
                xlim = [lower-0.05*r, upper+0.05*r]
                x = np.linspace(xlim[0], xlim[1], 1000)
            else:
                xlim = ax.get_xlim()
                x = np.linspace(s.min(), s.max(), 1000)
            multiplier = self._get_rescale_multiplier_for_key(key)
            subtractor = self._get_rescale_subtractor_for_key(key)
            ax.plot((x-subtractor)*multiplier,
                    [np.exp(lnprior(xi)) for xi in x], '-C3',
                    label='prior')

            for j in range(i, self.ndim):
                axes[j][i].set_xlim(xlim[0], xlim[1])
            for k in range(0, i):
                axes[i][k].set_ylim(xlim[0], xlim[1])

    def plot_prior_posterior(self, normal_stds=2):
        """ Plot the posterior in the context of the prior """
        fig, axes = plt.subplots(nrows=self.ndim, figsize=(8, 4*self.ndim))
        N = 1000
        from scipy.stats import gaussian_kde

        for i, (ax, key) in enumerate(zip(axes, self.theta_keys)):
            prior_dict = self.theta_prior[key]
            prior_func = self._generic_lnprior(**prior_dict)
            if prior_dict['type'] == 'unif':
                x = np.linspace(prior_dict['lower'], prior_dict['upper'], N)
                prior = prior_func(x)
                prior[0] = 0
                prior[-1] = 0
            elif prior_dict['type'] == 'log10unif':
                upper = prior_dict['log10upper']
                lower = prior_dict['log10lower']
                x = np.linspace(lower, upper, N)
                prior = [prior_func(xi) for xi in x]
            elif prior_dict['type'] == 'norm':
                lower = prior_dict['loc'] - normal_stds * prior_dict['scale']
                upper = prior_dict['loc'] + normal_stds * prior_dict['scale']
                x = np.linspace(lower, upper, N)
                prior = prior_func(x)
            elif prior_dict['type'] == 'halfnorm':
                lower = prior_dict['loc']
                upper = prior_dict['loc'] + normal_stds * prior_dict['scale']
                x = np.linspace(lower, upper, N)
                prior = [prior_func(xi) for xi in x]
            elif prior_dict['type'] == 'neghalfnorm':
                upper = prior_dict['loc']
                lower = prior_dict['loc'] - normal_stds * prior_dict['scale']
                x = np.linspace(lower, upper, N)
                prior = [prior_func(xi) for xi in x]
            else:
                raise ValueError('Not implemented for prior type {}'.format(
                    prior_dict['type']))
            priorln = ax.plot(x, prior, 'C3', label='prior')
            ax.set_xlabel(self.theta_symbols[i])

            s = self.samples[:, i]
            while len(s) > 10**4:
                # random downsample to avoid slow calculation of kde
                s = np.random.choice(s, size=int(len(s)/2.))
            kde = gaussian_kde(s)
            ax2 = ax.twinx()
            postln = ax2.plot(x, kde.pdf(x), 'k', label='posterior')
            ax2.set_yticklabels([])
            ax.set_yticklabels([])

        lns = priorln + postln
        labs = [l.get_label() for l in lns]
        axes[0].legend(lns, labs, loc=1, framealpha=0.8)

        fig.savefig('{}/{}_prior_posterior.png'.format(
            self.outdir, self.label))

    def plot_cumulative_max(self, **kwargs):
        """ Plot the cumulative twoF for the maximum posterior estimate

        See the pyfstat.core.plot_twoF_cumulative function for further details
        """
        d, maxtwoF = self.get_max_twoF()
        for key, val in self.theta_prior.items():
            if key not in d:
                d[key] = val

        if hasattr(self, 'search') is False:
            self._initiate_search_object()
        if self.binary is False:
            self.search.plot_twoF_cumulative(
                self.label, self.outdir, F0=d['F0'], F1=d['F1'], F2=d['F2'],
                Alpha=d['Alpha'], Delta=d['Delta'],
                tstart=self.minStartTime, tend=self.maxStartTime,
                **kwargs)
        else:
            self.search.plot_twoF_cumulative(
                self.label, self.outdir, F0=d['F0'], F1=d['F1'], F2=d['F2'],
                Alpha=d['Alpha'], Delta=d['Delta'], asini=d['asini'],
                period=d['period'], ecc=d['ecc'], argp=d['argp'], tp=d['tp'],
                tstart=self.minStartTime, tend=self.maxStartTime, **kwargs)

    def _generic_lnprior(self, **kwargs):
        """ Return a lambda function of the pdf

        Parameters
        ----------
        **kwargs:
            A dictionary containing 'type' of pdf and shape parameters

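        Examples
        --------
        For a uniform prior (`search` a hypothetical MCMCSearch instance):

        >>> lnprior = search._generic_lnprior(type='unif', lower=0., upper=2.)
        >>> lnprior(1.0)  # returns -np.log(2.)
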
        """

        def log_of_unif(x, a, b):
            above = x < b
            below = x > a
            if type(above) is not np.ndarray:
                if above and below:
                    return -np.log(b-a)
                else:
                    return -np.inf
            else:
                idxs = np.array([all(tup) for tup in zip(above, below)])
                p = np.zeros(len(x)) - np.inf
                p[idxs] = -np.log(b-a)
                return p

        def log_of_log10unif(x, log10lower, log10upper):
            log10x = np.log10(x)
            above = log10x < log10upper
            below = log10x > log10lower
            if type(above) is not np.ndarray:
                if above and below:
                    return -np.log(x*np.log(10)*(log10upper-log10lower))
                else:
                    return -np.inf
            else:
                idxs = np.array([all(tup) for tup in zip(above, below)])
                p = np.zeros(len(x)) - np.inf
                p[idxs] = -np.log(x*np.log(10)*(log10upper-log10lower))
                return p

        def log_of_halfnorm(x, loc, scale):
            if x < loc:
                return -np.inf
            else:
                return -0.5*((x-loc)**2/scale**2+np.log(0.5*np.pi*scale**2))

        def cauchy(x, x0, gamma):
            return 1.0/(np.pi*gamma*(1+((x-x0)/gamma)**2))

        def exp(x, x0, gamma):
            if x > x0:
                return np.log(gamma) - gamma*(x - x0)
            else:
                return -np.inf

        if kwargs['type'] == 'unif':
            return lambda x: log_of_unif(x, kwargs['lower'], kwargs['upper'])
        if kwargs['type'] == 'log10unif':
            return lambda x: log_of_log10unif(
                x, kwargs['log10lower'], kwargs['log10upper'])
        elif kwargs['type'] == 'halfnorm':
            return lambda x: log_of_halfnorm(x, kwargs['loc'], kwargs['scale'])
        elif kwargs['type'] == 'neghalfnorm':
            return lambda x: log_of_halfnorm(
                -x, kwargs['loc'], kwargs['scale'])
        elif kwargs['type'] == 'norm':
            return lambda x: -0.5*((x - kwargs['loc'])**2/kwargs['scale']**2
                                   + np.log(2*np.pi*kwargs['scale']**2))
        else:
            logging.info("kwargs:", kwargs)
            raise ValueError("Print unrecognise distribution")

    def _generate_rv(self, **kwargs):
        dist_type = kwargs.pop('type')
        if dist_type == "unif":
            return np.random.uniform(low=kwargs['lower'], high=kwargs['upper'])
        if dist_type == "log10unif":
            return 10**(np.random.uniform(low=kwargs['log10lower'],
                                          high=kwargs['log10upper']))
        if dist_type == "norm":
            return np.random.normal(loc=kwargs['loc'], scale=kwargs['scale'])
        if dist_type == "halfnorm":
            return np.abs(np.random.normal(loc=kwargs['loc'],
                                           scale=kwargs['scale']))
        if dist_type == "neghalfnorm":
            return -1 * np.abs(np.random.normal(loc=kwargs['loc'],
                                                scale=kwargs['scale']))
        if dist_type == "lognorm":
            return np.random.lognormal(
                mean=kwargs['loc'], sigma=kwargs['scale'])
        else:
            raise ValueError("dist_type {} unknown".format(dist_type))

    def _plot_walkers(self, sampler, symbols=None, alpha=0.8, color="k",
                      temp=0, lw=0.1, nprod=0, add_det_stat_burnin=False,
                      fig=None, axes=None, xoffset=0, plot_det_stat=False,
                      context='ggplot', subtractions=None, labelpad=0.05):
        """ Plot all the chains from a sampler """

        if context not in plt.style.available:
            raise ValueError((
                'The requested context {} is not available; please select a'
                ' context from `plt.style.available`').format(context))

        if np.ndim(axes) > 1:
            axes = axes.flatten()

        shape = sampler.chain.shape
        if len(shape) == 3:
            nwalkers, nsteps, ndim = shape
            chain = sampler.chain[:, :, :]
        if len(shape) == 4:
            ntemps, nwalkers, nsteps, ndim = shape
            if temp < ntemps:
                logging.info("Plotting temperature {} chains".format(temp))
            else:
                raise ValueError(("Requested temperature {} outside of"
                                  "available range").format(temp))
            chain = sampler.chain[temp, :, :, :]

        if subtractions is None:
            subtractions = [0 for i in range(ndim)]
        else:
            if len(subtractions) != self.ndim:
                raise ValueError('subtractions must be of length ndim')

        if plot_det_stat:
            extra_subplots = 1
        else:
            extra_subplots = 0
        with plt.style.context(context):
            plt.rcParams['text.usetex'] = True
            if fig is None and axes is None:
                fig = plt.figure(figsize=(4, 3.0*ndim))
                ax = fig.add_subplot(ndim+extra_subplots, 1, 1)
                axes = [ax] + [fig.add_subplot(ndim+extra_subplots, 1, i)
                               for i in range(2, ndim+1)]

            idxs = np.arange(chain.shape[1])
            burnin_idx = chain.shape[1] - nprod
            if hasattr(self, 'convergence_idx'):
                convergence_idx = self.convergence_idx
            else:
                convergence_idx = burnin_idx
            if ndim > 1:
                for i in range(ndim):
                    axes[i].ticklabel_format(useOffset=False, axis='y')
                    cs = chain[:, :, i].T
                    if burnin_idx > 0:
                        axes[i].plot(xoffset+idxs[:convergence_idx+1],
                                     cs[:convergence_idx+1]-subtractions[i],
                                     color="C3", alpha=alpha,
                                     lw=lw)
                        axes[i].axvline(xoffset+convergence_idx,
                                        color='k', ls='--', lw=0.25)
                    axes[i].plot(xoffset+idxs[burnin_idx:],
                                 cs[burnin_idx:]-subtractions[i],
                                 color="k", alpha=alpha, lw=lw)

                    axes[i].set_xlim(0, xoffset+idxs[-1])
                    if symbols:
                        if subtractions[i] == 0:
                            axes[i].set_ylabel(symbols[i], labelpad=labelpad)
                        else:
                            axes[i].set_ylabel(
                                symbols[i]+'$-$'+symbols[i]+'$_0$',
                                labelpad=labelpad)

                    if hasattr(self, 'convergence_diagnostic'):
                        ax = axes[i].twinx()
                        axes[i].set_zorder(ax.get_zorder()+1)
                        axes[i].patch.set_visible(False)
                        c_x = np.array(self.convergence_diagnosticx)
                        c_y = np.array(self.convergence_diagnostic)
                        break_idx = np.argmin(np.abs(c_x - burnin_idx))
                        ax.plot(c_x[:break_idx], c_y[:break_idx, i], '-C0',
                                zorder=-10)
                        ax.plot(c_x[break_idx:], c_y[break_idx:, i], '-C0',
                                zorder=-10)
                        if self.convergence_test_type == 'autocorr':
                            ax.set_ylabel(r'$\tau_\mathrm{exp}$')
                        elif self.convergence_test_type == 'GR':
                            ax.set_ylabel('PSRF')
                        ax.ticklabel_format(useOffset=False)
            else:
                axes[0].ticklabel_format(useOffset=False, axis='y')
                cs = chain[:, :, 0].T
                if burnin_idx:
                    axes[0].plot(idxs[:burnin_idx], cs[:burnin_idx],
                                 color="C3", alpha=alpha, lw=lw)
                axes[0].plot(idxs[burnin_idx:], cs[burnin_idx:], color="k",
                             alpha=alpha, lw=lw)
                if symbols:
                    axes[0].set_ylabel(symbols[0], labelpad=labelpad)

            axes[-1].set_xlabel(r'$\textrm{Number of steps}$', labelpad=0.2)

            if plot_det_stat:
                if len(axes) == ndim:
                    axes.append(fig.add_subplot(ndim+1, 1, ndim+1))

                lnl = sampler.lnlikelihood[temp, :, :]
                if burnin_idx and add_det_stat_burnin:
                    burn_in_vals = lnl[:, :burnin_idx].flatten()
                    try:
                        twoF_burnin = (burn_in_vals[~np.isnan(burn_in_vals)]
                                       - self.likelihoodcoef)
                        axes[-1].hist(twoF_burnin, bins=50, histtype='step',
                                      color='C3')
                    except ValueError:
                        logging.info('Det. Stat. hist failed, most likely all '
                                     'values were the same')
                else:
                    twoF_burnin = []
                prod_vals = lnl[:, burnin_idx:].flatten()
                try:
                    twoF = prod_vals[~np.isnan(prod_vals)]-self.likelihoodcoef
                    axes[-1].hist(twoF, bins=50, histtype='step', color='k')
                except ValueError:
                    logging.info('Det. Stat. hist failed, most likely all '
                                 'values were the same')
                if self.BSGL:
                    axes[-1].set_xlabel(