mcmc_based_searches.py 92.2 KB
Newer Older
Gregory Ashton's avatar
Gregory Ashton committed
1
2
""" Searches using MCMC-based methods """

3
import sys
Gregory Ashton's avatar
Gregory Ashton committed
4
import os
5
import copy
Gregory Ashton's avatar
Gregory Ashton committed
6
import logging
7
from collections import OrderedDict
8
9
10
11
12
13
14
15

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import emcee
import corner
import dill as pickle

16
import core
Gregory Ashton's avatar
Gregory Ashton committed
17
from core import tqdm, args, earth_ephem, sun_ephem
18
from optimal_setup_functions import get_V_estimate
Gregory Ashton's avatar
Gregory Ashton committed
19
20
from optimal_setup_functions import get_optimal_setup
import helper_functions
21
22


23
class MCMCSearch(core.BaseSearchClass):
Gregory Ashton's avatar
Gregory Ashton committed
24
    """ MCMC search using ComputeFstat"""
25
26

    # LaTeX symbol used to label each search parameter in plots. Raw strings
    # are used so that backslashes (e.g. ``\dot``) are never interpreted as
    # (invalid) string escape sequences.
    symbol_dictionary = dict(
        F0=r'$f$', F1=r'$\dot{f}$', F2=r'$\ddot{f}$', Alpha=r'$\alpha$',
        Delta=r'$\delta$', asini='asini', period='P', ecc='ecc', tp='tp',
        argp='argp')
    # Unit string appended to each parameter's plot label
    unit_dictionary = dict(
        F0='Hz', F1='Hz/s', F2='Hz/s$^2$', Alpha=r'rad', Delta='rad',
        asini='', period='s', ecc='', tp='', argp='')
    # Optional per-key plot rescaling; entries are looked up under the keys
    # 'multiplier', 'subtractor', 'symbol', 'label' and 'unit'
    rescale_dictionary = {}


Gregory Ashton's avatar
Gregory Ashton committed
36
    @helper_functions.initializer
    def __init__(self, label, outdir, theta_prior, tref, minStartTime,
                 maxStartTime, sftfilepath=None, nsteps=[100, 100],
                 nwalkers=100, ntemps=1, log10temperature_min=-5,
                 theta_initial=None, scatter_val=1e-10, rhohatmax=1000,
                 binary=False, BSGL=False, minCoverFreq=None,
                 maxCoverFreq=None, detectors=None, earth_ephem=None,
                 sun_ephem=None, injectSources=None, assumeSqrtSX=None):
        """ Set up the MCMC search: output directory, priors and temperatures

        Parameters
        ----------
        label, outdir: str
            A label and directory to read/write data from/to
        sftfilepath: str
            Pattern to match SFTs using wildcards (*?) and ranges [0-9];
            multiple patterns can be given separated by colons.
        theta_prior: dict
            Dictionary of priors and fixed values for the search parameters.
            For each parameter (key of the dict), if it is to be held fixed
            the value should be the constant float, if it is to be searched,
            the value should be a dictionary of the prior.
        theta_initial: dict, array, (None)
            Either a dictionary of distributions about which to distribute
            the initial walkers, an array (from which the walkers will be
            scattered by scatter_val), or None in which case the prior is
            used.
        scatter_val: float
            Scatter applied to the walkers when theta_initial is an array
        tref, minStartTime, maxStartTime: int
            GPS seconds of the reference time, start time and end time
        nsteps: list (m,)
            List specifying the number of steps to take, the last two entries
            give the nburn and nprod of the 'production' run, all entries
            before are for iterative initialisation steps (usually just one)
            e.g. [1000, 1000, 500].
        nwalkers, ntemps: int,
            The number of walkers and temperatures to use in the parallel
            tempered PTSampler.
        log10temperature_min: float < 0
            The log_10(tmin) value, the set of betas passed to PTSampler are
            generated from np.logspace(0, log10temperature_min, ntemps).
            If this evaluates as false (e.g. None or 0), betas is set to
            None.
        rhohatmax: float
            Upper bound for the SNR scale parameter (required to normalise the
            Bayes factor) - this needs to be carefully set when using the
            evidence.
        binary: Bool
            If true, search over binary parameters
        BSGL, injectSources, assumeSqrtSX:
            Passed unchanged to core.ComputeFstat when the search object is
            created
        detectors: str
            Two character reference to the data to use, specify None for no
            constraint.
        minCoverFreq, maxCoverFreq: float
            Minimum and maximum instantaneous frequency which will be covered
            over the SFT time span as passed to CreateFstatInput
        earth_ephem, sun_ephem: str
            Paths of the two files containing positions of Earth and Sun,
            respectively at evenly spaced times, as passed to CreateFstatInput
            If None defaults defined in BaseSearchClass will be used

        """

        # NOTE(review): the self.<argument> attributes read below are never
        # assigned in this method; presumably the
        # helper_functions.initializer decorator stores every argument on
        # self before the body runs - confirm against helper_functions.
        if os.path.isdir(outdir) is False:
            os.mkdir(outdir)
        self._add_log_file()
        logging.info(
            'Set-up MCMC search for model {} on data {}'.format(
                self.label, self.sftfilepath))
        # File checked for in this method (and renamed when args.clean is
        # set) and reported by run() when saved data is reused
        self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
        # Sets self.theta_keys, self.theta_idxs, self.theta_symbols and
        # self.fixed_theta from self.theta_prior
        self._unpack_input_theta()
        self.ndim = len(self.theta_keys)
        if self.log10temperature_min:
            self.betas = np.logspace(0, self.log10temperature_min, self.ntemps)
        else:
            self.betas = None

        # Fall back on the class-level ephemeris defaults when none are given
        if earth_ephem is None:
            self.earth_ephem = self.earth_ephem_default
        if sun_ephem is None:
            self.sun_ephem = self.sun_ephem_default

        # With args.clean set, any existing saved data is moved aside
        if args.clean and os.path.isfile(self.pickle_path):
            os.rename(self.pickle_path, self.pickle_path+".old")

        # Constant added to the detection statistic in logl()
        self.lnlikelihoodcoef = np.log(70./self.rhohatmax**4)
        self._log_input()
117

118
    def _log_input(self):
119
        logging.info('theta_prior = {}'.format(self.theta_prior))
120
        logging.info('nwalkers={}'.format(self.nwalkers))
121
122
123
124
        logging.info('scatter_val = {}'.format(self.scatter_val))
        logging.info('nsteps = {}'.format(self.nsteps))
        logging.info('ntemps = {}'.format(self.ntemps))
        logging.info('log10temperature_min = {}'.format(
125
            self.log10temperature_min))
126

127
    def _initiate_search_object(self):
        """ Create the core.ComputeFstat instance and store it as self.search

        All settings are forwarded from attributes of this object; the
        search is always created with transient=False.
        """
        logging.info('Setting up search object')
        search_kwargs = dict(
            tref=self.tref,
            sftfilepath=self.sftfilepath,
            minCoverFreq=self.minCoverFreq,
            maxCoverFreq=self.maxCoverFreq,
            earth_ephem=self.earth_ephem,
            sun_ephem=self.sun_ephem,
            detectors=self.detectors,
            BSGL=self.BSGL,
            transient=False,
            minStartTime=self.minStartTime,
            maxStartTime=self.maxStartTime,
            binary=self.binary,
            injectSources=self.injectSources,
            assumeSqrtSX=self.assumeSqrtSX)
        self.search = core.ComputeFstat(**search_kwargs)
137
138

    def logp(self, theta_vals, theta_prior, theta_keys, search):
        """ Return the summed log-prior of theta_vals

        Each value is paired with its key in theta_keys and evaluated with
        the function built by self._generic_lnprior from theta_prior[key].
        The `search` argument is accepted but not used here.
        """
        contributions = []
        for key, value in zip(theta_keys, theta_vals):
            lnprior = self._generic_lnprior(**theta_prior[key])
            contributions.append(lnprior(value))
        return np.sum(contributions)

    def logl(self, theta, search):
        """ Return the detection statistic at theta plus lnlikelihoodcoef

        The sampled values in theta are written into self.fixed_theta at the
        positions given by self.theta_idxs (self.fixed_theta is modified in
        place), and the full vector is passed to
        search.compute_fullycoherent_det_stat_single_point.
        """
        for position, slot in enumerate(self.theta_idxs):
            self.fixed_theta[slot] = theta[position]
        full_point = self.fixed_theta
        detstat = search.compute_fullycoherent_det_stat_single_point(
            *full_point)
        return detstat + self.lnlikelihoodcoef
149

150
    def _unpack_input_theta(self):
151
        full_theta_keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta']
152
153
154
        if self.binary:
            full_theta_keys += [
                'asini', 'period', 'ecc', 'tp', 'argp']
155
156
        full_theta_keys_copy = copy.copy(full_theta_keys)

157
158
        full_theta_symbols = ['$f$', '$\dot{f}$', '$\ddot{f}$', r'$\alpha$',
                              r'$\delta$']
159
160
        if self.binary:
            full_theta_symbols += [
161
                'asini', 'period', 'ecc', 'tp', 'argp']
162

163
164
        self.theta_keys = []
        fixed_theta_dict = {}
165
        for key, val in self.theta_prior.iteritems():
166
167
            if type(val) is dict:
                fixed_theta_dict[key] = 0
Gregory Ashton's avatar
Gregory Ashton committed
168
                self.theta_keys.append(key)
169
170
171
172
173
174
            elif type(val) in [float, int, np.float64]:
                fixed_theta_dict[key] = val
            else:
                raise ValueError(
                    'Type {} of {} in theta not recognised'.format(
                        type(val), key))
Gregory Ashton's avatar
Gregory Ashton committed
175
            full_theta_keys_copy.pop(full_theta_keys_copy.index(key))
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190

        if len(full_theta_keys_copy) > 0:
            raise ValueError(('Input dictionary `theta` is missing the'
                              'following keys: {}').format(
                                  full_theta_keys_copy))

        self.fixed_theta = [fixed_theta_dict[key] for key in full_theta_keys]
        self.theta_idxs = [full_theta_keys.index(k) for k in self.theta_keys]
        self.theta_symbols = [full_theta_symbols[i] for i in self.theta_idxs]

        idxs = np.argsort(self.theta_idxs)
        self.theta_idxs = [self.theta_idxs[i] for i in idxs]
        self.theta_symbols = [self.theta_symbols[i] for i in idxs]
        self.theta_keys = [self.theta_keys[i] for i in idxs]

191
    def _check_initial_points(self, p0):
192
193
194
195
196
197
198
199
200
201
202
203
204
        for nt in range(self.ntemps):
            logging.info('Checking temperature {} chains'.format(nt))
            initial_priors = np.array([
                self.logp(p, self.theta_prior, self.theta_keys, self.search)
                for p in p0[nt]])
            number_of_initial_out_of_bounds = sum(initial_priors == -np.inf)

            if number_of_initial_out_of_bounds > 0:
                logging.warning(
                    'Of {} initial values, {} are -np.inf due to the prior'
                    .format(len(initial_priors),
                            number_of_initial_out_of_bounds))

205
                p0 = self._generate_new_p0_to_fix_initial_points(
206
207
                    p0, nt, initial_priors)

208
    def _generate_new_p0_to_fix_initial_points(self, p0, nt, initial_priors):
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
        logging.info('Attempting to correct intial values')
        idxs = np.arange(self.nwalkers)[initial_priors == -np.inf]
        count = 0
        while sum(initial_priors == -np.inf) > 0 and count < 100:
            for j in idxs:
                p0[nt][j] = (p0[nt][np.random.randint(0, self.nwalkers)]*(
                             1+np.random.normal(0, 1e-10, self.ndim)))
            initial_priors = np.array([
                self.logp(p, self.theta_prior, self.theta_keys,
                          self.search)
                for p in p0[nt]])
            count += 1

        if sum(initial_priors == -np.inf) > 0:
            logging.info('Failed to fix initial priors')
        else:
            logging.info('Suceeded to fix initial priors')

        return p0
228

229
    def _OLD_run_sampler_with_progress_bar(self, sampler, ns, p0):
        """ Advance sampler by ns steps from p0 behind a tqdm progress bar

        The values yielded by sampler.sample are discarded; the sampler
        itself is returned.
        """
        steps = sampler.sample(p0, iterations=ns)
        for _ in tqdm(steps, total=ns):
            continue
        return sampler

234
235
    def setup_convergence_testing(
            self, convergence_period=10, convergence_length=10,
            convergence_burnin_fraction=0.25, convergence_threshold_number=10,
            convergence_threshold=1.2, convergence_prod_threshold=2,
            convergence_plot_upper_lim=2, convergence_early_stopping=True):
        """
        If called, convergence testing is used during the MCMC simulation

        This uses the Gelman-Rubin statistic based on the ratio of between and
        within walkers variance. The original statistic was developed for
        multiple (independent) MCMC simulations, in this context we simply use
        the walkers

        Parameters
        ----------
        convergence_period: int
            period (in number of steps) at which to test convergence
        convergence_length: int
            number of steps to use in testing convergence - this should be
            large enough to measure the variance, but if it is too long
            this will result in incorrect early convergence tests. Must not
            exceed convergence_period.
        convergence_burnin_fraction: float [0, 1]
            the fraction of the burn-in period after which to start testing
        convergence_threshold_number: int
            the burn-in is stopped early once the number of consecutive
            passed tests exceeds this number
        convergence_threshold: float
            the threshold to use in diagnosing convergence. Gelman & Rubin
            recommend a value of 1.2, 1.1 for strict convergence
        convergence_prod_threshold: float
            the threshold to test the production values with
        convergence_plot_upper_lim: float
            the upper limit to use in the diagnostic plot
        convergence_early_stopping: bool
            if true, stop the burnin early if convergence is reached

        Raises
        ------
        ValueError
            If convergence_length is larger than convergence_period
        """

        if convergence_length > convergence_period:
            # Equal values are allowed (and are the defaults)
            raise ValueError(
                'convergence_length must be <= convergence_period')
        logging.info('Setting up convergence testing')
        self.convergence_length = convergence_length
        self.convergence_period = convergence_period
        self.convergence_burnin_fraction = convergence_burnin_fraction
        self.convergence_prod_threshold = convergence_prod_threshold
        # History of the statistic and of the step index it is attributed to
        self.convergence_diagnostic = []
        self.convergence_diagnosticx = []
        self.convergence_threshold_number = convergence_threshold_number
        self.convergence_threshold = convergence_threshold
        # Running count of consecutive passed burn-in tests
        self.convergence_number = 0
        self.convergence_plot_upper_lim = convergence_plot_upper_lim
        self.convergence_early_stopping = convergence_early_stopping
285

286
    def _get_convergence_statistic(self, i, sampler):
287
        s = sampler.chain[0, :, i-self.convergence_length+1:i+1, :]
288
289
290
        N = float(self.convergence_length)
        M = float(self.nwalkers)
        W = np.mean(np.var(s, axis=1), axis=0)
291
292
        per_walker_mean = np.mean(s, axis=1)
        mean = np.mean(per_walker_mean, axis=0)
293
294
295
        B = N / (M-1.) * np.sum((per_walker_mean-mean)**2, axis=0)
        Vhat = (N-1)/N * W + (M+1)/(M*N) * B
        c = Vhat/W
296
        self.convergence_diagnostic.append(c)
297
        self.convergence_diagnosticx.append(i - self.convergence_length/2)
298
299
        return c

300
    def _burnin_convergence_test(self, i, sampler, nburn):
301
302
        if i < self.convergence_burnin_fraction*nburn:
            return False
303
        if np.mod(i+1, self.convergence_period) != 0:
304
            return False
305
        c = self._get_convergence_statistic(i, sampler)
306
307
        if np.all(c < self.convergence_threshold):
            self.convergence_number += 1
308
309
        else:
            self.convergence_number = 0
310
311
        if self.convergence_early_stopping:
            return self.convergence_number > self.convergence_threshold_number
312

313
    def _prod_convergence_test(self, i, sampler, nburn):
314
315
316
        testA = i > nburn + self.convergence_length
        testB = np.mod(i+1, self.convergence_period) == 0
        if testA and testB:
317
            self._get_convergence_statistic(i, sampler)
318

319
    def _check_production_convergence(self, k):
320
321
322
323
324
325
326
327
        bools = np.any(
            np.array(self.convergence_diagnostic)[k:, :]
            > self.convergence_prod_threshold, axis=1)
        if np.any(bools):
            logging.warning(
                '{} convergence tests in the production run of {} failed'
                .format(np.sum(bools), len(bools)))

328
    def _run_sampler(self, sampler, p0, nprod=0, nburn=0):
        """ Run nburn burn-in then nprod production steps of the sampler

        Without convergence testing (i.e. setup_convergence_testing was never
        called) all nburn+nprod steps are run in a single pass. With
        convergence testing, the burn-in is tested as it runs and may stop
        early; the production steps then start from the last burn-in
        position (or from p0 if no burn-in step was taken) and are tested
        periodically, with a warning logged for any failures.

        Returns
        -------
        sampler: the sampler which was passed in
        """
        if not hasattr(self, 'convergence_period'):
            for result in tqdm(sampler.sample(p0, iterations=nburn+nprod),
                               total=nburn+nprod):
                pass
            return sampler

        logging.info('Running {} burn-in steps with convergence testing'
                     .format(nburn))
        # Fall back on p0 so that production can start even when the burn-in
        # loop never iterates (nburn == 0)
        current_p = p0
        iterator = tqdm(sampler.sample(p0, iterations=nburn), total=nburn)
        for i, output in enumerate(iterator):
            current_p = output[0]
            if self._burnin_convergence_test(i, sampler, nburn):
                logging.info(
                    'Converged at {} before max number {} of steps reached'
                    .format(i, nburn))
                self.convergence_idx = i
                break
        iterator.close()
        logging.info('Running {} production steps'.format(nprod))
        # The production step counter starts at nburn even if the burn-in
        # was stopped early
        j = nburn
        # Diagnostics recorded from index k onwards belong to production
        k = len(self.convergence_diagnostic)
        for result in tqdm(sampler.sample(current_p, iterations=nprod),
                           total=nprod):
            self._prod_convergence_test(j, sampler, nburn)
            j += 1
        self._check_production_convergence(k)
        return sampler
355

356
    def run(self, proposal_scale_factor=2, create_plots=True, **kwargs):
        """ Run the MCMC simulation

        If previously saved data is judged usable it is loaded and no
        sampling is done. Otherwise the search object and an emcee PTSampler
        are created, the initialisation stages (all entries of self.nsteps
        except the last two) are run in turn, followed by the final burn-in
        and production run. The post-burn-in samples are stored on self and
        saved.

        Parameters
        ----------
        proposal_scale_factor: float
            Passed to emcee.PTSampler as its `a` argument
        create_plots: bool
            If true, save a plot of the walkers after each stage
        **kwargs:
            Passed to self._plot_walkers

        Sets self.samples, self.lnprobs, self.lnlikes and
        self.all_lnlikelihood; nothing is returned.
        """

        self.old_data_is_okay_to_use = self._check_old_data_is_okay_to_use()
        if self.old_data_is_okay_to_use is True:
            logging.warning('Using saved data from {}'.format(
                self.pickle_path))
            d = self.get_saved_data_dictionary()
            self.samples = d['samples']
            self.lnprobs = d['lnprobs']
            self.lnlikes = d['lnlikes']
            self.all_lnlikelihood = d['all_lnlikelihood']
            return

        self._initiate_search_object()

        sampler = emcee.PTSampler(
            self.ntemps, self.nwalkers, self.ndim, self.logl, self.logp,
            logpargs=(self.theta_prior, self.theta_keys, self.search),
            loglargs=(self.search,), betas=self.betas, a=proposal_scale_factor)

        p0 = self._generate_initial_p0()
        p0 = self._apply_corrections_to_p0(p0)
        self._check_initial_points(p0)

        # Initialisation stages: every entry of nsteps except the last two.
        # After each stage a new p0 is drawn from the sampler and the
        # sampler is reset.
        ninit_steps = len(self.nsteps) - 2
        for j, n in enumerate(self.nsteps[:-2]):
            logging.info('Running {}/{} initialisation with {} steps'.format(
                j, ninit_steps, n))
            sampler = self._run_sampler(sampler, p0, nburn=n)
            logging.info("Mean acceptance fraction: {}"
                         .format(np.mean(sampler.acceptance_fraction, axis=1)))
            if self.ntemps > 1:
                logging.info("Tswap acceptance fraction: {}"
                             .format(sampler.tswap_acceptance_fraction))
            if create_plots:
                fig, axes = self._plot_walkers(sampler,
                                              symbols=self.theta_symbols,
                                              **kwargs)
                fig.tight_layout()
                fig.savefig('{}/{}_init_{}_walkers.png'.format(
                    self.outdir, self.label, j), dpi=400)

            p0 = self._get_new_p0(sampler)
            p0 = self._apply_corrections_to_p0(p0)
            self._check_initial_points(p0)
            sampler.reset()

        # Final stage: the last two entries of nsteps are the burn-in and
        # production lengths; with a single entry there is no burn-in
        if len(self.nsteps) > 1:
            nburn = self.nsteps[-2]
        else:
            nburn = 0
        nprod = self.nsteps[-1]
        logging.info('Running final burn and prod with {} steps'.format(
            nburn+nprod))
        sampler = self._run_sampler(sampler, p0, nburn=nburn, nprod=nprod)
        logging.info("Mean acceptance fraction: {}"
                     .format(np.mean(sampler.acceptance_fraction, axis=1)))
        if self.ntemps > 1:
            logging.info("Tswap acceptance fraction: {}"
                         .format(sampler.tswap_acceptance_fraction))

        if create_plots:
            fig, axes = self._plot_walkers(sampler, symbols=self.theta_symbols,
                                          nprod=nprod, **kwargs)
            fig.tight_layout()
            fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label),
                        dpi=200)

        # Keep only the steps after the burn-in. samples, lnprobs and lnlikes
        # use index 0 on the leading axis and are flattened over walkers and
        # steps (presumably the leading axis is temperature - confirm against
        # emcee); all_lnlikelihood keeps every index of the leading axis.
        samples = sampler.chain[0, :, nburn:, :].reshape((-1, self.ndim))
        lnprobs = sampler.lnprobability[0, :, nburn:].reshape((-1))
        lnlikes = sampler.lnlikelihood[0, :, nburn:].reshape((-1))
        all_lnlikelihood = sampler.lnlikelihood[:, :, nburn:]
        self.samples = samples
        self.lnprobs = lnprobs
        self.lnlikes = lnlikes
        self.all_lnlikelihood = all_lnlikelihood
        self._save_data(sampler, samples, lnprobs, lnlikes, all_lnlikelihood)
434

435
    def _get_rescale_multiplier_for_key(self, key):
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
        """ Get the rescale multiplier from the rescale_dictionary

        Can either be a float, a string (in which case it is interpretted as
        a attribute of the MCMCSearch class, e.g. minStartTime, or non-existent
        in which case 0 is returned
        """
        if key not in self.rescale_dictionary:
            return 1

        if 'multiplier' in self.rescale_dictionary[key]:
            val = self.rescale_dictionary[key]['multiplier']
            if type(val) == str:
                if hasattr(self, val):
                    multiplier = getattr(
                        self, self.rescale_dictionary[key]['multiplier'])
                else:
                    raise ValueError(
                        "multiplier {} not a class attribute".format(val))
            else:
                multiplier = val
        else:
            multiplier = 1
        return multiplier

460
    def _get_rescale_subtractor_for_key(self, key):
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
        """ Get the rescale subtractor from the rescale_dictionary

        Can either be a float, a string (in which case it is interpretted as
        a attribute of the MCMCSearch class, e.g. minStartTime, or non-existent
        in which case 0 is returned
        """
        if key not in self.rescale_dictionary:
            return 0

        if 'subtractor' in self.rescale_dictionary[key]:
            val = self.rescale_dictionary[key]['subtractor']
            if type(val) == str:
                if hasattr(self, val):
                    subtractor = getattr(
                        self, self.rescale_dictionary[key]['subtractor'])
                else:
                    raise ValueError(
                        "subtractor {} not a class attribute".format(val))
            else:
                subtractor = val
        else:
            subtractor = 0
        return subtractor

485
    def _scale_samples(self, samples, theta_keys):
486
        """ Scale the samples using the rescale_dictionary """
487
488
489
490
        for key in theta_keys:
            if key in self.rescale_dictionary:
                idx = theta_keys.index(key)
                s = samples[:, idx]
491
                subtractor = self._get_rescale_subtractor_for_key(key)
492
                s = s - subtractor
493
                multiplier = self._get_rescale_multiplier_for_key(key)
494
                s *= multiplier
495
496
                samples[:, idx] = s

497
498
        return samples

499
    def _get_labels(self):
500
        """ Combine the units, symbols and rescaling to give labels """
501

502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
        labels = []
        for key in self.theta_keys:
            label = None
            s = self.symbol_dictionary[key]
            s.replace('_{glitch}', r'_\textrm{glitch}')
            u = self.unit_dictionary[key]
            if key in self.rescale_dictionary:
                if 'symbol' in self.rescale_dictionary[key]:
                    s = self.rescale_dictionary[key]['symbol']
                if 'label' in self.rescale_dictionary[key]:
                    label = self.rescale_dictionary[key]['label']
                if 'unit' in self.rescale_dictionary[key]:
                    u = self.rescale_dictionary[key]['unit']
            if label is None:
                label = '{} \n [{}]'.format(s, u)
            labels.append(label)
        return labels
519

520
521
    def plot_corner(self, figsize=(7, 7), add_prior=False, nstds=None,
                    label_offset=0.4, dpi=300, rc_context={},
                    tglitch_ratio=False, fig_and_axes=None, save_fig=True,
                    **kwargs):
        """ Generate a corner plot of the posterior

        Using the `corner` package (https://pypi.python.org/pypi/corner/),
        generate estimates of the posterior from the production samples.

        Parameters
        ----------
        figsize: tuple (7, 7)
            Figure size in inches (passed to plt.subplots)
        add_prior: bool, str
            If true, plot the prior as a red line. If 'full' then for uniform
            priors plot the full extent of the prior.
        nstds: float
            The number of standard deviations to plot centered on the median;
            ignored if a `range` is given in kwargs
        label_offset: float
            Offset the labels from the plot: useful to prevent overlapping the
            tick labels with the axis labels
        dpi: int
            Passed to plt.savefig
        rc_context: dict
            Dictionary of rc values to set while generating the figure (see
            matplotlib rc for more details); it is only read, never modified
        tglitch_ratio: bool
            If true, and tglitch is a parameter, plot posteriors as the
            fractional time at which the glitch occurs instead of the actual
            time
        fig_and_axes: tuple
            fig and axes to plot on, the axes must be of the right shape,
            namely (ndim, ndim)
        save_fig: bool
            If true, save the figure, else return the fig, axes

        Note: kwargs are passed on to corner.corner

        Returns
        -------
        fig, axes: only when save_fig is false, otherwise None

        """

        truths = kwargs.get('truths')
        if truths is not None and len(truths) != self.ndim:
            logging.warning('len(Truths) != ndim, Truths will be ignored')
            kwargs['truths'] = None

        if self.ndim < 2:
            # A corner plot needs at least two dimensions: fall back on a
            # histogram of the single searched parameter
            with plt.rc_context(rc_context):
                if fig_and_axes is None:
                    fig, ax = plt.subplots(figsize=figsize)
                else:
                    fig, ax = fig_and_axes
                ax.hist(self.samples, bins=50, histtype='stepfilled')
                ax.set_xlabel(self.theta_symbols[0])

            if save_fig:
                fig.savefig('{}/{}_corner.png'.format(
                    self.outdir, self.label), dpi=dpi)
                return
            return fig, ax

        with plt.rc_context(rc_context):
            if fig_and_axes is None:
                fig, axes = plt.subplots(self.ndim, self.ndim,
                                         figsize=figsize)
            else:
                fig, axes = fig_and_axes

            # Work on a copy: _scale_samples modifies its input in place
            samples_plt = copy.copy(self.samples)
            labels = self._get_labels()

            samples_plt = self._scale_samples(samples_plt, self.theta_keys)

            if tglitch_ratio:
                for j, k in enumerate(self.theta_keys):
                    if k == 'tglitch':
                        s = samples_plt[:, j]
                        samples_plt[:, j] = (
                            s - self.minStartTime)/(
                                self.maxStartTime - self.minStartTime)
                        labels[j] = r'$R_{\textrm{glitch}}$'

            # Any numeric nstds is accepted (the parameter is documented as a
            # float); an explicit `range` in kwargs takes precedence only
            # when nstds is not given
            if nstds is not None and 'range' not in kwargs:
                _range = []
                for j, s in enumerate(samples_plt.T):
                    median = np.median(s)
                    std = np.std(s)
                    _range.append((median - nstds*std, median + nstds*std))
            elif 'range' in kwargs:
                _range = kwargs.pop('range')
            else:
                _range = None

            # NOTE(review): `normed` is the histogram keyword of older
            # matplotlib releases; newer releases use `density` - confirm
            # against the matplotlib version in use
            hist_kwargs = kwargs.pop('hist_kwargs', dict())
            if 'normed' not in hist_kwargs:
                hist_kwargs['normed'] = True

            fig_triangle = corner.corner(samples_plt,
                                         labels=labels,
                                         fig=fig,
                                         bins=50,
                                         max_n_ticks=4,
                                         plot_contours=True,
                                         plot_datapoints=True,
                                         label_kwargs={'fontsize': 8},
                                         data_kwargs={'alpha': 0.1,
                                                      'ms': 0.5},
                                         range=_range,
                                         hist_kwargs=hist_kwargs,
                                         **kwargs)

            axes_list = fig_triangle.get_axes()
            axes = np.array(axes_list).reshape(self.ndim, self.ndim)
            plt.draw()
            # Move the axis labels clear of the tick labels
            for ax in axes[:, 0]:
                ax.yaxis.set_label_coords(-label_offset, 0.5)
            for ax in axes[-1, :]:
                ax.xaxis.set_label_coords(0.5, -label_offset)
            for ax in axes_list:
                ax.set_rasterized(True)
                ax.set_rasterization_zorder(-10)
            plt.tight_layout(h_pad=0.0, w_pad=0.0)
            fig.subplots_adjust(hspace=0.05, wspace=0.05)

            if add_prior:
                # The unscaled samples are passed: the rescaling is applied
                # inside _add_prior_to_corner
                self._add_prior_to_corner(axes, self.samples, add_prior)

            if save_fig:
                fig_triangle.savefig('{}/{}_corner.png'.format(
                    self.outdir, self.label), dpi=dpi)
            else:
                return fig, axes
648

649
    def _add_prior_to_corner(self, axes, samples, add_prior):
        """ Overlay the prior pdf on the diagonal panels of a corner plot

        Parameters
        ----------
        axes: 2D array of matplotlib axes
            The grid of corner-plot axes, indexed as `axes[row][column]`.
        samples: array
            Samples, indexed as `samples[:, i]` for the i-th entry of
            `self.theta_keys`.
        add_prior: bool or str
            If equal to 'full' and the prior of a parameter is uniform, the
            prior is drawn over its whole range and the panel limits are
            widened to that range plus 5% padding on each side. Otherwise
            the prior is drawn over the range of the samples and the
            existing limits of the diagonal panel are kept.

        """
        npoints = 1000
        for i, key in enumerate(self.theta_keys):
            prior_dict = self.theta_prior[key]
            diag_ax = axes[i][i]
            par_samples = samples[:, i]
            lnprior = self._generic_lnprior(**prior_dict)

            if add_prior == 'full' and prior_dict['type'] == 'unif':
                lower = prior_dict['lower']
                upper = prior_dict['upper']
                width = upper-lower
                xlim = [lower-0.05*width, upper+0.05*width]
                x = np.linspace(xlim[0], xlim[1], npoints)
            else:
                xlim = diag_ax.get_xlim()
                x = np.linspace(par_samples.min(), par_samples.max(), npoints)

            # The prior is evaluated at x, but drawn at the rescaled abscissa
            multiplier = self._get_rescale_multiplier_for_key(key)
            subtractor = self._get_rescale_subtractor_for_key(key)
            prior_pdf = [np.exp(lnprior(xi)) for xi in x]
            diag_ax.plot((x-subtractor)*multiplier, prior_pdf, '-C3',
                         label='prior')

            # Parameter i is the x-axis of column i (rows i and below) and
            # the y-axis of row i (columns left of the diagonal)
            for row in range(i, self.ndim):
                axes[row][i].set_xlim(xlim[0], xlim[1])
            for col in range(0, i):
                axes[i][col].set_ylim(xlim[0], xlim[1])
673

674
675
676
677
678
679
680
681
    def plot_prior_posterior(self, normal_stds=2):
        """ Plot the posterior in the context of the prior

        One panel is drawn per entry of `self.theta_keys`, showing the prior
        pdf (red) and a Gaussian kernel-density estimate of the samples
        (black, on a twin y-axis). The figure is saved as
        `{outdir}/{label}_prior_posterior.png`.

        Parameters
        ----------
        normal_stds: float
            For the 'norm', 'halfnorm' and 'neghalfnorm' priors, the number
            of `scale` widths away from `loc` over which to draw the curves.

        Raises
        ------
        ValueError
            If a prior type is not one of 'unif', 'log10unif', 'norm',
            'halfnorm' or 'neghalfnorm'.

        """
        from scipy.stats import gaussian_kde

        fig, axes = plt.subplots(nrows=self.ndim, figsize=(8, 4*self.ndim))
        # plt.subplots returns a bare Axes (not an array) when nrows == 1
        axes = np.atleast_1d(axes)
        N = 1000

        for i, (ax, key) in enumerate(zip(axes, self.theta_keys)):
            prior_dict = self.theta_prior[key]
            lnprior_func = self._generic_lnprior(**prior_dict)
            prior_type = prior_dict['type']

            if prior_type == 'unif':
                x = np.linspace(prior_dict['lower'], prior_dict['upper'], N)
            elif prior_type == 'log10unif':
                # The prior function and the samples are both in linear
                # (not log10) units, so the grid must be too
                x = np.logspace(prior_dict['log10lower'],
                                prior_dict['log10upper'], N)
            elif prior_type == 'norm':
                lower = prior_dict['loc'] - normal_stds * prior_dict['scale']
                upper = prior_dict['loc'] + normal_stds * prior_dict['scale']
                x = np.linspace(lower, upper, N)
            elif prior_type == 'halfnorm':
                lower = prior_dict['loc']
                upper = prior_dict['loc'] + normal_stds * prior_dict['scale']
                x = np.linspace(lower, upper, N)
            elif prior_type == 'neghalfnorm':
                upper = prior_dict['loc']
                lower = prior_dict['loc'] - normal_stds * prior_dict['scale']
                x = np.linspace(lower, upper, N)
            else:
                raise ValueError('Not implemented for prior type {}'.format(
                    prior_dict['type']))

            # _generic_lnprior returns the *log* of the pdf: exponentiate it
            # so that the curve labelled 'prior' really is the prior pdf.
            # Point-wise evaluation is used since not every prior function
            # accepts arrays.
            prior = np.exp([lnprior_func(xi) for xi in x])
            if prior_type == 'unif':
                # Close the box at the edges of the uniform prior
                prior[0] = 0
                prior[-1] = 0

            priorln = ax.plot(x, prior, 'r', label='prior')
            ax.set_xlabel(self.theta_symbols[i])

            s = self.samples[:, i]
            while len(s) > 10**4:
                # random downsample to avoid slow calculation of kde
                s = np.random.choice(s, size=int(len(s)/2.))
            kde = gaussian_kde(s)
            ax2 = ax.twinx()
            postln = ax2.plot(x, kde.pdf(x), 'k', label='posterior')
            ax2.set_yticklabels([])
            ax.set_yticklabels([])

        # A single legend on the top panel, built from the last pair of lines
        lns = priorln + postln
        labs = [l.get_label() for l in lns]
        axes[0].legend(lns, labs, loc=1, framealpha=0.8)

        fig.savefig('{}/{}_prior_posterior.png'.format(
            self.outdir, self.label))

731
    def plot_cumulative_max(self, **kwargs):
        """ Plot the cumulative twoF for the maximum posterior estimate

        See the pyfstat.core.plot_twoF_cumulative function for further details

        Parameters
        ----------
        kwargs:
            Passed through to `self.search.plot_twoF_cumulative`.

        """
        d, _ = self.get_max_twoF()
        # Parameters which were not returned by get_max_twoF take their
        # value from theta_prior. `items` (rather than the Python 2 only
        # `iteritems`) works under both Python 2 and 3.
        for key, val in self.theta_prior.items():
            if key not in d:
                d[key] = val

        if not hasattr(self, 'search'):
            self._initiate_search_object()

        if self.binary is False:
            self.search.plot_twoF_cumulative(
                self.label, self.outdir, F0=d['F0'], F1=d['F1'], F2=d['F2'],
                Alpha=d['Alpha'], Delta=d['Delta'],
                tstart=self.minStartTime, tend=self.maxStartTime,
                **kwargs)
        else:
            # tp takes its own value (it was previously passed d['argp'])
            self.search.plot_twoF_cumulative(
                self.label, self.outdir, F0=d['F0'], F1=d['F1'], F2=d['F2'],
                Alpha=d['Alpha'], Delta=d['Delta'], asini=d['asini'],
                period=d['period'], ecc=d['ecc'], argp=d['argp'], tp=d['tp'],
                tstart=self.minStartTime, tend=self.maxStartTime, **kwargs)
Gregory Ashton's avatar
Gregory Ashton committed
755

756
    def _generic_lnprior(self, **kwargs):
757
758
759
760
761
762
763
764
765
        """ Return a lambda function of the pdf

        Parameters
        ----------
        kwargs: dict
            A dictionary containing 'type' of pdf and shape parameters

        """

Gregory Ashton's avatar
Gregory Ashton committed
766
        def log_of_unif(x, a, b):
767
768
769
770
771
772
773
774
775
776
777
778
779
            above = x < b
            below = x > a
            if type(above) is not np.ndarray:
                if above and below:
                    return -np.log(b-a)
                else:
                    return -np.inf
            else:
                idxs = np.array([all(tup) for tup in zip(above, below)])
                p = np.zeros(len(x)) - np.inf
                p[idxs] = -np.log(b-a)
                return p

Gregory Ashton's avatar
Gregory Ashton committed
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
        def log_of_log10unif(x, log10lower, log10upper):
            log10x = np.log10(x)
            above = log10x < log10upper
            below = log10x > log10lower
            if type(above) is not np.ndarray:
                if above and below:
                    return -np.log(x*np.log(10)*(log10upper-log10lower))
                else:
                    return -np.inf
            else:
                idxs = np.array([all(tup) for tup in zip(above, below)])
                p = np.zeros(len(x)) - np.inf
                p[idxs] = -np.log(x*np.log(10)*(log10upper-log10lower))
                return p

        def log_of_halfnorm(x, loc, scale):
796
            if x < loc:
797
798
799
800
801
802
803
804
805
806
807
808
809
810
                return -np.inf
            else:
                return -0.5*((x-loc)**2/scale**2+np.log(0.5*np.pi*scale**2))

        def cauchy(x, x0, gamma):
            return 1.0/(np.pi*gamma*(1+((x-x0)/gamma)**2))

        def exp(x, x0, gamma):
            if x > x0:
                return np.log(gamma) - gamma*(x - x0)
            else:
                return -np.inf

        if kwargs['type'] == 'unif':
Gregory Ashton's avatar
Gregory Ashton committed
811
812
813
814
            return lambda x: log_of_unif(x, kwargs['lower'], kwargs['upper'])
        if kwargs['type'] == 'log10unif':
            return lambda x: log_of_log10unif(
                x, kwargs['log10lower'], kwargs['log10upper'])
815
        elif kwargs['type'] == 'halfnorm':
Gregory Ashton's avatar
Gregory Ashton committed
816
            return lambda x: log_of_halfnorm(x, kwargs['loc'], kwargs['scale'])
817
        elif kwargs['type'] == 'neghalfnorm':
Gregory Ashton's avatar
Gregory Ashton committed
818
819
            return lambda x: log_of_halfnorm(
                -x, kwargs['loc'], kwargs['scale'])
820
821
822
823
824
825
826
        elif kwargs['type'] == 'norm':
            return lambda x: -0.5*((x - kwargs['loc'])**2/kwargs['scale']**2
                                   + np.log(2*np.pi*kwargs['scale']**2))
        else:
            logging.info("kwargs:", kwargs)
            raise ValueError("Print unrecognise distribution")

827
    def _generate_rv(self, **kwargs):
828
829
830
        dist_type = kwargs.pop('type')
        if dist_type == "unif":
            return np.random.uniform(low=kwargs['lower'], high=kwargs['upper'])
Gregory Ashton's avatar
Gregory Ashton committed
831
832
833
        if dist_type == "log10unif":
            return 10**(np.random.uniform(low=kwargs['log10lower'],
                                          high=kwargs['log10upper']))
834
835
836
837
838
        if dist_type == "norm":
            return np.random.normal(loc=kwargs['loc'], scale=kwargs['scale'])
        if dist_type == "halfnorm":
            return np.abs(np.random.normal(loc=kwargs['loc'],
                                           scale=kwargs['scale']))
839
840
841
        if dist_type == "neghalfnorm":
            return -1 * np.abs(np.random.normal(loc=kwargs['loc'],
                                                scale=kwargs['scale']))
842
843
844
845
846
847
        if dist_type == "lognorm":
            return np.random.lognormal(
                mean=kwargs['loc'], sigma=kwargs['scale'])
        else:
            raise ValueError("dist_type {} unknown".format(dist_type))

848
849
850
    def _plot_walkers(self, sampler, symbols=None, alpha=0.4, color="k",
                      temp=0, lw=0.1, nprod=0, add_det_stat_burnin=False,
                      fig=None, axes=None, xoffset=0, plot_det_stat=False,
                      context='ggplot', subtractions=None, labelpad=0.05):
        """ Plot all the chains from a sampler

        Parameters
        ----------
        sampler:
            An object with a `chain` attribute of shape
            (nwalkers, nsteps, ndim) or (ntemps, nwalkers, nsteps, ndim)
            and, if `plot_det_stat` is set, an `lnlikelihood` attribute
            indexed as [temp, walker, step].
        symbols: list of str, optional
            The y-axis label of each parameter.
        alpha, lw: float
            Transparency and line width of the chain traces.
        color: str
            Colour of the production (post burn-in) traces.
        temp: int
            Index of the temperature to plot.
        nprod: int
            Number of steps at the end of the chain which are production
            steps; the steps before them are drawn in red.
        add_det_stat_burnin: bool
            If true, also histogram the detection statistic of the burn-in
            steps (in red).
        fig, axes: optional
            An existing figure and its axes to draw on; if both are None a
            new figure is created.
        xoffset: int
            Offset added to the step number on the x-axis.
        plot_det_stat: bool
            If true, add a histogram of the detection statistic in an extra
            panel below the chains.
        context: str
            The name of a matplotlib style in `plt.style.available`.
        subtractions: list, optional
            A value to subtract from the chains of each parameter before
            plotting; defaults to zero for all parameters.
        labelpad: float
            Padding passed to `set_ylabel`.

        Returns
        -------
        fig, axes:
            The figure and the list (or array) of axes drawn on.

        Raises
        ------
        ValueError
            If the context is not available, the chain is not 3 or 4
            dimensional, the temperature is out of range, or subtractions
            has the wrong length.

        """

        if context not in plt.style.available:
            raise ValueError((
                'The requested context {} is not available; please select a'
                ' context from `plt.style.available`').format(context))

        if np.ndim(axes) > 1:
            axes = axes.flatten()

        shape = sampler.chain.shape
        if len(shape) == 3:
            nwalkers, nsteps, ndim = shape
            chain = sampler.chain[:, :, :]
        elif len(shape) == 4:
            ntemps, nwalkers, nsteps, ndim = shape
            if temp < ntemps:
                logging.info("Plotting temperature {} chains".format(temp))
            else:
                raise ValueError(("Requested temperature {} outside of "
                                  "available range").format(temp))
            chain = sampler.chain[temp, :, :, :]
        else:
            raise ValueError(
                'sampler.chain must have 3 or 4 dimensions, not {}'.format(
                    len(shape)))

        if subtractions is None:
            subtractions = [0 for i in range(ndim)]
        elif len(subtractions) != ndim:
            raise ValueError('subtractions must be of length ndim')

        extra_subplots = 1 if plot_det_stat else 0

        def hist_det_stat(ax, vals, hist_color):
            # Histogram the non-NaN values; np.histogram raises ValueError
            # when it cannot construct bins for the data
            try:
                ax.hist(vals[~np.isnan(vals)], bins=50, histtype='step',
                        color=hist_color)
            except ValueError:
                logging.info('Det. Stat. hist failed, most likely all '
                             'values where the same')

        with plt.style.context((context)):
            plt.rcParams['text.usetex'] = True
            if fig is None and axes is None:
                fig = plt.figure(figsize=(4, 3.0*ndim))
                axes = [fig.add_subplot(ndim+extra_subplots, 1, i)
                        for i in range(1, ndim+1)]

            steps = xoffset + np.arange(chain.shape[1])
            burnin_idx = chain.shape[1] - nprod
            if hasattr(self, 'convergence_idx'):
                convergence_idx = self.convergence_idx
            else:
                convergence_idx = burnin_idx

            if ndim > 1:
                for i in range(ndim):
                    axes[i].ticklabel_format(useOffset=False, axis='y')
                    # Transpose to (nsteps, nwalkers): one line per walker
                    cs = chain[:, :, i].T
                    if burnin_idx > 0:
                        axes[i].plot(steps[:convergence_idx+1],
                                     cs[:convergence_idx+1]-subtractions[i],
                                     color="r", alpha=alpha,
                                     lw=lw)
                        axes[i].axvline(xoffset+convergence_idx,
                                        color='k', ls='--', lw=0.25)
                    axes[i].plot(steps[burnin_idx:],
                                 cs[burnin_idx:]-subtractions[i],
                                 color=color, alpha=alpha, lw=lw)
                    if symbols:
                        if subtractions[i] == 0:
                            axes[i].set_ylabel(symbols[i], labelpad=labelpad)
                        else:
                            axes[i].set_ylabel(
                                symbols[i]+'$-$'+symbols[i]+'$_0$',
                                labelpad=labelpad)

                    if hasattr(self, 'convergence_diagnostic'):
                        ax = axes[i].twinx()
                        c_x = np.array(self.convergence_diagnosticx)
                        c_y = np.array(self.convergence_diagnostic)
                        break_idx = np.argmin(np.abs(c_x - burnin_idx))
                        ax.plot(c_x[:break_idx], c_y[:break_idx, i], '-b')
                        ax.plot(c_x[break_idx:], c_y[break_idx:, i], '-b')
                        ax.set_ylabel('PSRF')
                        ax.ticklabel_format(useOffset=False)
                        ax.set_ylim(0.5, self.convergence_plot_upper_lim)
            else:
                axes[0].ticklabel_format(useOffset=False, axis='y')
                # The last axis of chain indexes the (single) parameter, not
                # the temperature, which has already been selected above
                cs = chain[:, :, 0].T
                if burnin_idx:
                    axes[0].plot(steps[:burnin_idx], cs[:burnin_idx],
                                 color="r", alpha=alpha, lw=lw)
                axes[0].plot(steps[burnin_idx:], cs[burnin_idx:],
                             color=color, alpha=alpha, lw=lw)
                if symbols:
                    axes[0].set_ylabel(symbols[0], labelpad=labelpad)

            axes[-1].set_xlabel(r'$\textrm{Number of steps}$', labelpad=0.2)

            if plot_det_stat:
                if len(axes) == ndim:
                    # A flattened array of axes has no `append`
                    axes = list(axes)
                    axes.append(fig.add_subplot(ndim+1, 1, ndim+1))

                lnl = sampler.lnlikelihood[temp, :, :]
                if burnin_idx and add_det_stat_burnin:
                    burn_in_vals = lnl[:, :burnin_idx].flatten()
                    hist_det_stat(axes[-1], burn_in_vals, 'r')
                else:
                    burn_in_vals = []
                prod_vals = lnl[:, burnin_idx:].flatten()
                hist_det_stat(axes[-1], prod_vals, 'k')

                if self.BSGL:
                    axes[-1].set_xlabel(r'$\mathcal{B}_\mathrm{S/GL}$')
                else:
                    axes[-1].set_xlabel(r'$\widetilde{2\mathcal{F}}$')
                axes[-1].set_ylabel(r'$\textrm{Counts}$')
                combined_vals = np.append(burn_in_vals, prod_vals)
                # Axis limits cannot be NaN or infinite
                combined_vals = combined_vals[np.isfinite(combined_vals)]
                if len(combined_vals) > 0:
                    minv = np.min(combined_vals)
                    maxv = np.max(combined_vals)
                    Range = abs(maxv-minv)
                    axes[-1].set_xlim(minv-0.1*Range, maxv+0.1*Range)

                xfmt = matplotlib.ticker.ScalarFormatter()
                xfmt.set_powerlimits((-4, 4))
                axes[-1].xaxis.set_major_formatter(xfmt)

        return fig, axes

986
    def _apply_corrections_to_p0(self, p0):
Gregory Ashton's avatar
Gregory Ashton committed
987
988
989
        """ Apply any correction to the initial p0 values """
        return p0

990
    def _generate_scattered_p0(self, p):
991
        """ Generate a set of p0s scattered about p """
Gregory Ashton's avatar
Gregory Ashton committed
992
        p0 = [[p + self.scatter_val * p * np.random.randn(self.ndim)
993
994
995
996
               for i in xrange(self.nwalkers)]
              for j in xrange(self.ntemps)]
        return p0

997
    def _generate_initial_p0(self):
998
999
1000
        """ Generate a set of init vals for the walkers """

        if type(self.theta_initial) == dict:
1001
            logging.info('Generate initial values from initial dictionary')
1002
            if hasattr(self, 'nglitch') and self.nglitch > 1:
1003
                raise ValueError('Initial dict not implemented for nglitch>1')
1004
            p0 = [[[self._generate_rv(**self.theta_initial[key])
1005
1006
1007
                    for key in self.theta_keys]
                   for i in range(self.nwalkers)]
                  for j in range(self.ntemps)]
1008
1009
        elif type(self.theta_initial) == list:
            logging.info('Generate initial values from list of theta_initial')
1010
            p0 = [[[self._generate_rv(**val)
1011
1012
1013
                    for val in self.theta_initial]
                   for i in range(self.nwalkers)]
                  for j in range(self.ntemps)]
1014
        elif self.theta_initial is None:
1015
            logging.info('Generate initial values from prior dictionary')
1016
            p0 = [[[self._generate_rv(**self.theta_prior[key])
1017
1018
1019
1020
                    for key in self.theta_keys]
                   for i in range(self.nwalkers)]
                  for j in range(self.ntemps)]
        elif len(self.theta_initial) == self.ndim:
1021
            p0 = self._generate_scattered_p0(self.theta_initial)
1022
1023
1024
1025
1026
        else:
            raise ValueError('theta_initial not understood')

        return p0

1027
    def _get_new_p0(self, sampler):
1028
1029
1030
1031
1032
1033
        """ Returns new initial positions for walkers are burn0 stage

        This returns new positions for all walkers by scattering points about
        the maximum posterior with scale `scatter_val`.

        """
Gregory Ashton's avatar
Gregory Ashton committed
1034
1035
1036
1037
        temp_idx = 0
        pF = sampler.chain[temp_idx, :, :, :]
        lnl = sampler.lnlikelihood[temp_idx, :, :]
        lnp = sampler.lnprobability[temp_idx, :, :]
1038
1039

        # General warnings about the state of lnp
Gregory Ashton's avatar
Gregory Ashton committed
1040
        if np.any(np.isnan(lnp)):
1041
1042
            logging.warning(
                "Of {} lnprobs {} are nan".format(
Gregory Ashton's avatar
Gregory Ashton committed
1043
1044
                    np.shape(lnp), np.sum(np.isnan(lnp))))
        if np.any(np.isposinf(lnp)):
1045
1046
            logging.warning(
                "Of {} lnprobs {} are +np.inf".format(
Gregory Ashton's avatar
Gregory Ashton committed
1047
1048
                    np.shape(lnp), np.sum(np.isposinf(lnp))))
        if np.any(np.isneginf(lnp)):
1049
1050
            logging.warning(
                "Of {} lnprobs {} are -np.inf".format(
Gregory Ashton's avatar
Gregory Ashton committed
1051
                    np.shape(lnp), np.sum(np.isneginf(lnp))))
1052

1053
1054
        lnp_finite = copy.copy(lnp)
        lnp_finite[np.isinf(lnp)] = np.nan
Gregory Ashton's avatar
Gregory Ashton committed
1055
1056
        idx = np.unravel_index(np.nanargmax(lnp_finite), lnp_finite.shape)
        p = pF[idx]
1057
        p0 = self._generate_scattered_p0(p)
1058

1059
1060
1061
1062
1063
1064
1065
1066
        self.search.BSGL = False
        twoF = self.logl(p, self.search)
        self.search.BSGL = self.BSGL

        logging.info(('Gen. new p0 from pos {} which had det. stat.={:2.1f},'
                      ' twoF={:2.1f} and lnp={:2.1f}')
                     .format(idx[1], lnl[idx], twoF, lnp_finite[idx]))

1067
1068
        return