mcmc_based_searches.py 82.2 KB
Newer Older
Gregory Ashton's avatar
Gregory Ashton committed
1
2
""" Searches using MCMC-based methods """

3
import sys
Gregory Ashton's avatar
Gregory Ashton committed
4
import os
5
import copy
Gregory Ashton's avatar
Gregory Ashton committed
6
import logging
7
from collections import OrderedDict
8
9
10
11
12
13
14
15

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import emcee
import corner
import dill as pickle

16
import core
Gregory Ashton's avatar
Gregory Ashton committed
17
from core import tqdm, args, earth_ephem, sun_ephem
18
from optimal_setup_functions import get_V_estimate
Gregory Ashton's avatar
Gregory Ashton committed
19
20
from optimal_setup_functions import get_optimal_setup
import helper_functions
21
22


23
class MCMCSearch(core.BaseSearchClass):
Gregory Ashton's avatar
Gregory Ashton committed
24
    """ MCMC search using ComputeFstat"""
Gregory Ashton's avatar
Gregory Ashton committed
25
    @helper_functions.initializer
    def __init__(self, label, outdir, theta_prior, tref, minStartTime,
                 maxStartTime, sftfilepath=None, nsteps=[100, 100],
                 nwalkers=100, ntemps=1, log10temperature_min=-5,
                 theta_initial=None, scatter_val=1e-10,
                 binary=False, BSGL=False, minCoverFreq=None,
                 maxCoverFreq=None, detectors=None, earth_ephem=None,
                 sun_ephem=None, injectSources=None, assumeSqrtSX=None):
        """
        Parameters
        ----------
        label, outdir: str
            A label and directory to read/write data from/to. The directory
            (including any missing parent directories) is created if needed.
        theta_prior: dict
            Dictionary of priors and fixed values for the search parameters.
            For each parameter (key of the dict), if it is to be held fixed
            the value should be the constant float; if it is to be searched,
            the value should be a dictionary describing the prior.
        tref, minStartTime, maxStartTime: int
            GPS seconds of the reference time, start time and end time
        sftfilepath: str
            File pattern to match SFTs
        nsteps: list (m,)
            List specifying the number of steps to take; the last two entries
            give the nburn and nprod of the 'production' run, all entries
            before are for iterative initialisation steps (usually just one),
            e.g. [1000, 1000, 500]. Note: the default list object is shared
            between instances, so it should not be mutated in place.
        nwalkers, ntemps: int
            The number of walkers and temperatures to use in the parallel
            tempered PTSampler.
        log10temperature_min: float < 0
            The log_10(tmin) value; the set of betas passed to PTSampler is
            generated from np.logspace(0, log10temperature_min, ntemps). If
            falsy (0 or None) betas is set to None.
        theta_initial: dict, array, (None)
            Either a dictionary of distributions about which to distribute the
            initial walkers, an array (from which the walkers will be
            scattered by scatter_val), or None, in which case the prior is
            used.
        scatter_val: float
            Scatter applied to the walkers when theta_initial is an array.
        binary: bool
            If true, search over binary parameters
        BSGL: bool
            Passed unchanged to core.ComputeFstat (presumably switches to the
            line-robust BSGL statistic -- TODO confirm).
        minCoverFreq, maxCoverFreq: float
            Minimum and maximum instantaneous frequency which will be covered
            over the SFT time span as passed to CreateFstatInput
        detectors: str
            Two character reference to the data to use, specify None for no
            constraint.
        earth_ephem, sun_ephem: str
            Paths of the two files containing positions of Earth and Sun,
            respectively, at evenly spaced times, as passed to
            CreateFstatInput. If None, the defaults defined in
            BaseSearchClass are used.
        injectSources, assumeSqrtSX:
            Passed unchanged to core.ComputeFstat.

        """
        if not os.path.isdir(outdir):
            # makedirs (unlike mkdir) also creates missing parent directories
            os.makedirs(outdir)
        self.add_log_file()
        logging.info(
            'Set-up MCMC search for model {} on data {}'.format(
                self.label, self.sftfilepath))
        self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
        self.unpack_input_theta()
        self.ndim = len(self.theta_keys)
        if self.log10temperature_min:
            self.betas = np.logspace(0, self.log10temperature_min, self.ntemps)
        else:
            self.betas = None

        if earth_ephem is None:
            self.earth_ephem = self.earth_ephem_default
        if sun_ephem is None:
            self.sun_ephem = self.sun_ephem_default

        if args.clean and os.path.isfile(self.pickle_path):
            # keep the previous results rather than deleting them
            os.rename(self.pickle_path, self.pickle_path+".old")

        # raw strings: '\d' is an invalid escape sequence in a normal string
        self.symbol_dictionary = dict(
            F0='$f$', F1=r'$\dot{f}$', F2=r'$\ddot{f}$', alpha=r'$\alpha$',
            delta=r'$\delta$')
        self.unit_dictionary = dict(
            F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad')
        self.rescale_dictionary = {}

        self.log_input()

    def log_input(self):
        """Record the main sampler settings in the log at INFO level."""
        # (message template, attribute name) pairs, logged in this order;
        # attributes are read lazily so logging order matches evaluation.
        entries = (
            ('theta_prior = {}', 'theta_prior'),
            ('nwalkers={}', 'nwalkers'),
            ('scatter_val = {}', 'scatter_val'),
            ('nsteps = {}', 'nsteps'),
            ('ntemps = {}', 'ntemps'),
            ('log10temperature_min = {}', 'log10temperature_min'),
        )
        for template, attr_name in entries:
            logging.info(template.format(getattr(self, attr_name)))

    def initiate_search_object(self):
        """Build the fully-coherent core.ComputeFstat object as self.search.

        All configuration is taken from attributes set in __init__; the
        search is always non-transient.
        """
        logging.info('Setting up search object')
        fstat_options = dict(
            tref=self.tref,
            sftfilepath=self.sftfilepath,
            minCoverFreq=self.minCoverFreq,
            maxCoverFreq=self.maxCoverFreq,
            earth_ephem=self.earth_ephem,
            sun_ephem=self.sun_ephem,
            detectors=self.detectors,
            BSGL=self.BSGL,
            transient=False,
            minStartTime=self.minStartTime,
            maxStartTime=self.maxStartTime,
            binary=self.binary,
            injectSources=self.injectSources,
            assumeSqrtSX=self.assumeSqrtSX,
        )
        self.search = core.ComputeFstat(**fstat_options)

    def logp(self, theta_vals, theta_prior, theta_keys, search):
        """Return the total log-prior of a point in the searched parameters.

        Parameters
        ----------
        theta_vals: array_like
            Values of the searched parameters, ordered like theta_keys.
        theta_prior: dict
            Prior specification per key, as accepted by generic_lnprior.
        theta_keys: list of str
            Names of the searched parameters.
        search:
            Not used here; it is part of the logpargs tuple handed to the
            sampler in run().
        """
        log_priors = []
        for value, key in zip(theta_vals, theta_keys):
            prior_fn = self.generic_lnprior(**theta_prior[key])
            log_priors.append(prior_fn(value))
        return np.sum(log_priors)

    def logl(self, theta, search):
        """Return the detection statistic at the point `theta`.

        The searched values are written into their slots in self.fixed_theta
        (so that list is updated in place), and the full parameter vector is
        passed to search.compute_fullycoherent_det_stat_single_point.
        """
        for position, full_index in enumerate(self.theta_idxs):
            self.fixed_theta[full_index] = theta[position]
        return search.compute_fullycoherent_det_stat_single_point(
            *self.fixed_theta)

    def unpack_input_theta(self):
        """Split self.theta_prior into searched and fixed parameters.

        Sets:
          self.theta_keys    -- searched keys (dict-valued priors), in the
                                canonical order of the full parameter list
          self.fixed_theta   -- full parameter vector; fixed values in place,
                                0 placeholders for searched parameters
          self.theta_idxs    -- index of each searched key in the full vector
          self.theta_symbols -- plot symbol for each searched key

        Raises
        ------
        ValueError
            If a value is neither a dict nor a float/int/np.float64, if a key
            is not a recognised parameter, or if a required key is missing.
        """
        full_theta_keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta']
        full_theta_symbols = ['$f$', r'$\dot{f}$', r'$\ddot{f}$', r'$\alpha$',
                              r'$\delta$']
        if self.binary:
            full_theta_keys += ['asini', 'period', 'ecc', 'tp', 'argp']
            # exactly one symbol per key (a duplicated 'period' used to shift
            # the ecc/tp/argp labels by one)
            full_theta_symbols += ['asini', 'period', 'ecc', 'tp', 'argp']

        missing_keys = copy.copy(full_theta_keys)
        self.theta_keys = []
        fixed_theta_dict = {}
        for key, val in self.theta_prior.items():
            if type(val) is dict:
                fixed_theta_dict[key] = 0
                self.theta_keys.append(key)
            elif type(val) in [float, int, np.float64]:
                fixed_theta_dict[key] = val
            else:
                raise ValueError(
                    'Type {} of {} in theta not recognised'.format(
                        type(val), key))
            if key not in missing_keys:
                raise ValueError(
                    'Key {} in theta is not a recognised parameter; expected '
                    'keys are {}'.format(key, full_theta_keys))
            missing_keys.remove(key)

        if len(missing_keys) > 0:
            raise ValueError(('Input dictionary `theta` is missing the '
                              'following keys: {}').format(missing_keys))

        self.fixed_theta = [fixed_theta_dict[key] for key in full_theta_keys]
        self.theta_idxs = [full_theta_keys.index(k) for k in self.theta_keys]
        self.theta_symbols = [full_theta_symbols[i] for i in self.theta_idxs]

        # dict iteration order is not guaranteed (Python 2), so sort the
        # searched parameters into the canonical full-vector order
        idxs = np.argsort(self.theta_idxs)
        self.theta_idxs = [self.theta_idxs[i] for i in idxs]
        self.theta_symbols = [self.theta_symbols[i] for i in idxs]
        self.theta_keys = [self.theta_keys[i] for i in idxs]

    def check_initial_points(self, p0):
        """Check that every initial walker position has a finite log-prior.

        For each temperature `nt`, self.logp is evaluated for every walker in
        p0[nt]. If any are -inf, the positions are repaired with
        generate_new_p0_to_fix_initial_points.

        Parameters
        ----------
        p0: array_like
            Initial positions indexed as p0[temperature][walker]; presumably
            an array of shape (ntemps, nwalkers, ndim) -- TODO confirm.

        Returns
        -------
        p0:
            The checked (and possibly repaired) positions. The repair
            routine assigns into p0 in place, so callers ignoring this value
            still see the fixes; returning it makes the result explicit.
        """
        for nt in range(self.ntemps):
            logging.info('Checking temperature {} chains'.format(nt))
            initial_priors = np.array([
                self.logp(p, self.theta_prior, self.theta_keys, self.search)
                for p in p0[nt]])
            number_of_initial_out_of_bounds = sum(initial_priors == -np.inf)

            if number_of_initial_out_of_bounds > 0:
                logging.warning(
                    'Of {} initial values, {} are -np.inf due to the prior'
                    .format(len(initial_priors),
                            number_of_initial_out_of_bounds))

                p0 = self.generate_new_p0_to_fix_initial_points(
                    p0, nt, initial_priors)
        return p0

    def generate_new_p0_to_fix_initial_points(self, p0, nt, initial_priors):
        """Move out-of-prior walkers at temperature `nt` onto valid walkers.

        Each walker whose log-prior is -inf is replaced by a copy of a
        randomly chosen walker with a finite log-prior, multiplied by
        (1 + N(0, 1e-10)) per dimension. The bad and good sets are recomputed
        on each pass, so repaired walkers become donors and are never
        overwritten by invalid positions. At most 100 passes are made.

        Parameters
        ----------
        p0: array_like
            Walker positions indexed as p0[temperature][walker]; modified
            in place.
        nt: int
            Temperature index to repair.
        initial_priors: array_like (nwalkers,)
            Log-prior of each walker in p0[nt].

        Returns
        -------
        p0: the same (in-place modified) object.
        """
        logging.info('Attempting to correct intial values')
        initial_priors = np.asarray(initial_priors)
        count = 0
        while np.any(initial_priors == -np.inf) and count < 100:
            bad = np.flatnonzero(initial_priors == -np.inf)
            good = np.flatnonzero(initial_priors != -np.inf)
            if len(good) == 0:
                # no valid walker to copy from; resampling cannot help
                break
            for j in bad:
                donor = good[np.random.randint(0, len(good))]
                p0[nt][j] = (p0[nt][donor] *
                             (1 + np.random.normal(0, 1e-10, self.ndim)))
            initial_priors = np.array([
                self.logp(p, self.theta_prior, self.theta_keys,
                          self.search)
                for p in p0[nt]])
            count += 1

        if np.any(initial_priors == -np.inf):
            logging.warning('Failed to fix initial priors')
        else:
            logging.info('Suceeded to fix initial priors')

        return p0

    def OLD_run_sampler_with_progress_bar(self, sampler, ns, p0):
        """Legacy helper: advance `sampler` by `ns` steps starting from `p0`.

        A tqdm progress bar wraps the sampling generator; the same sampler
        object is returned.
        """
        progress = tqdm(sampler.sample(p0, iterations=ns), total=ns)
        for _ in progress:
            continue
        return sampler

    def setup_convergence_testing(
            self, convergence_period=10, convergence_length=10,
            convergence_burnin_fraction=0.25, convergence_threshold_number=10,
            convergence_threshold=1.2, convergence_prod_threshold=2,
            convergence_plot_upper_lim=2):
        """
        If called, convergence testing is used during the MCMC simulation

        This uses the Gelman-Rubin statistic based on the ratio of between and
        within walker variance. The original statistic was developed for
        multiple (independent) MCMC simulations; in this context we simply use
        the walkers as the independent chains. Once this method has been
        called, run_sampler performs the tests.

        Parameters
        ----------
        convergence_period: int
            period (in number of steps) at which to test convergence
        convergence_length: int
            number of steps to use in testing convergence - this should be
            large enough to measure the variance, but if it is too long
            this will result in incorrect early convergence tests. Must not
            exceed convergence_period.
        convergence_burnin_fraction: float [0, 1]
            the fraction of the burn-in period after which to start testing
        convergence_threshold_number: int
            burn-in is stopped (and production begins) once the test has
            passed more than this many consecutive times
        convergence_threshold: float
            the threshold to use in diagnosing convergence. Gelman & Rubin
            recommend a value of 1.2, 1.1 for strict convergence
        convergence_prod_threshold: float
            the threshold to test the production values with
        convergence_plot_upper_lim: float
            the upper limit to use in the diagnostic plot

        Raises
        ------
        ValueError
            If convergence_length > convergence_period.
        """

        if convergence_length > convergence_period:
            raise ValueError(
                'convergence_length must be <= convergence_period')
        logging.info('Setting up convergence testing')
        self.convergence_length = convergence_length
        self.convergence_period = convergence_period
        self.convergence_burnin_fraction = convergence_burnin_fraction
        self.convergence_prod_threshold = convergence_prod_threshold
        # diagnostic values and the step index each one is attributed to
        self.convergence_diagnostic = []
        self.convergence_diagnosticx = []
        self.convergence_threshold_number = convergence_threshold_number
        self.convergence_threshold = convergence_threshold
        # running count of consecutive passed burn-in tests
        self.convergence_number = 0
        self.convergence_plot_upper_lim = convergence_plot_upper_lim

    def get_convergence_statistic(self, i, sampler):
        """Compute the Gelman-Rubin statistic using the steps ending at `i`.

        Each walker of sampler.chain[0] (the first temperature) is treated as
        an independent chain, using the last convergence_length steps up to
        and including step i. With n = convergence_length and m = nwalkers:

            W     = mean over walkers of the within-walker variance
            B     = n / (m - 1) * sum_j (mean_j - mean)**2
            Vhat  = (n - 1) / n * W + (m + 1) / (m * n) * B
            R     = sqrt(Vhat / W)

        The previous version used convergence_period as n and divided the
        between-walker variance by n a second time. That underweighted B and
        made convergence look better than it is.

        The value is appended to self.convergence_diagnostic and the step
        i - convergence_length/2 (roughly the window centre) to
        self.convergence_diagnosticx.

        Returns
        -------
        c: np.ndarray
            One R value per searched parameter.
        """
        s = sampler.chain[0, :, i-self.convergence_length+1:i+1, :]
        n = float(self.convergence_length)
        m = float(self.nwalkers)
        W = np.mean(np.var(s, axis=1), axis=0)
        per_walker_mean = np.mean(s, axis=1)
        mean = np.mean(per_walker_mean, axis=0)
        B = n / (m - 1.) * np.sum((per_walker_mean - mean)**2, axis=0)
        Vhat = (n - 1.) / n * W + (m + 1.) / (m * n) * B
        c = np.sqrt(Vhat/W)
        self.convergence_diagnostic.append(c)
        self.convergence_diagnosticx.append(i - self.convergence_length/2)
        return c

    def burnin_convergence_test(self, i, sampler, nburn):
        """Decide whether burn-in can stop after step `i`.

        Testing only starts after convergence_burnin_fraction * nburn steps
        and happens every convergence_period steps. A pass (all statistics
        below convergence_threshold) extends the consecutive-pass counter and
        a failure resets it. Returns True once the counter exceeds
        convergence_threshold_number.
        """
        too_early = i < self.convergence_burnin_fraction*nburn
        if too_early or np.mod(i+1, self.convergence_period) != 0:
            return False

        statistic = self.get_convergence_statistic(i, sampler)
        passed = np.all(statistic < self.convergence_threshold)
        self.convergence_number = (
            self.convergence_number + 1 if passed else 0)
        return self.convergence_number > self.convergence_threshold_number

    def prod_convergence_test(self, i, sampler, nburn):
        """Record the convergence statistic during production (no decision).

        A statistic is computed at every convergence_period-th step once
        step `i` is more than convergence_length steps past nburn.
        """
        past_window = i > nburn + self.convergence_length
        on_period = np.mod(i+1, self.convergence_period) == 0
        if not (past_window and on_period):
            return
        self.get_convergence_statistic(i, sampler)

    def check_production_convergence(self, k):
        """Warn if any production-run convergence test failed.

        Parameters
        ----------
        k: int
            Index into self.convergence_diagnostic of the first value that was
            recorded during the production run.

        A test fails when any parameter's statistic exceeds
        convergence_prod_threshold. If no production tests were recorded
        (e.g. the production run was shorter than convergence_length) there
        is nothing to check. Previously this case raised an IndexError when
        the diagnostic list was empty.
        """
        production = self.convergence_diagnostic[k:]
        if len(production) == 0:
            return
        bools = np.any(
            np.array(production) > self.convergence_prod_threshold, axis=1)
        if np.any(bools):
            logging.warning(
                '{} convergence tests in the production run of {} failed'
                .format(np.sum(bools), len(bools)))

    def run_sampler(self, sampler, p0, nprod=0, nburn=0):
        """Advance `sampler` by nburn burn-in and nprod production steps.

        If setup_convergence_testing has been called (detected via the
        `convergence_period` attribute), burn-in may stop early when
        burnin_convergence_test passes; in that case self.convergence_idx is
        set to the step at which it stopped. Production then restarts from
        the last burn-in position, and its convergence statistics are checked
        at the end. Otherwise all nburn + nprod steps are run in one go.

        Parameters
        ----------
        sampler:
            The (emcee PTSampler) sampler to advance.
        p0: array_like
            Starting positions.
        nprod, nburn: int
            Number of production and burn-in steps.

        Returns
        -------
        sampler: the same sampler object.
        """
        if hasattr(self, 'convergence_period'):
            logging.info('Running {} burn-in steps with convergence testing'
                         .format(nburn))
            # Production starts where burn-in ends. If nburn == 0 the loop
            # never runs, so start from p0 instead of an unbound variable.
            p_start = p0
            iterator = tqdm(sampler.sample(p0, iterations=nburn), total=nburn)
            for i, output in enumerate(iterator):
                p_start = output[0]
                if self.burnin_convergence_test(i, sampler, nburn):
                    logging.info(
                        'Converged at {} before max number {} of steps reached'
                        .format(i, nburn))
                    self.convergence_idx = i
                    break
            iterator.close()
            logging.info('Running {} production steps'.format(nprod))
            j = nburn
            k = len(self.convergence_diagnostic)
            for result in tqdm(sampler.sample(p_start, iterations=nprod),
                               total=nprod):
                self.prod_convergence_test(j, sampler, nburn)
                j += 1
            self.check_production_convergence(k)
            return sampler
        else:
            for result in tqdm(sampler.sample(p0, iterations=nburn+nprod),
                               total=nburn+nprod):
                pass
            return sampler

    def run(self, proposal_scale_factor=2, create_plots=True, **kwargs):
        """Run the parallel-tempered MCMC search, or reload a saved one.

        If check_old_data_is_okay_to_use() returns True, the saved sampler,
        samples, lnprobs and lnlikes are loaded and nothing is run.
        Otherwise:

        1. An emcee.PTSampler is built with self.logl / self.logp.
        2. Each entry of nsteps[:-2] is run as an initialisation stage. After
           each stage the walkers are re-seeded from get_new_p0 and the
           sampler is reset.
        3. A final run of nsteps[-2] burn-in (0 if nsteps has one entry) plus
           nsteps[-1] production steps is made.

        Production samples and log-probabilities of the first temperature are
        stored on self and passed to save_data.

        Parameters
        ----------
        proposal_scale_factor: float
            Passed to emcee.PTSampler as `a` (the proposal scale).
        create_plots: bool
            If True, walker plots are saved after each stage.
        **kwargs:
            Passed to plot_walkers.
        """
        def log_acceptance(smp):
            # per-temperature acceptance, plus swap rate when tempering
            logging.info("Mean acceptance fraction: {}"
                         .format(np.mean(smp.acceptance_fraction, axis=1)))
            if self.ntemps > 1:
                logging.info("Tswap acceptance fraction: {}"
                             .format(smp.tswap_acceptance_fraction))

        self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
        if self.old_data_is_okay_to_use is True:
            logging.warning('Using saved data from {}'.format(
                self.pickle_path))
            saved = self.get_saved_data()
            for attr_name in ('sampler', 'samples', 'lnprobs', 'lnlikes'):
                setattr(self, attr_name, saved[attr_name])
            return

        self.initiate_search_object()

        sampler = emcee.PTSampler(
            self.ntemps, self.nwalkers, self.ndim, self.logl, self.logp,
            logpargs=(self.theta_prior, self.theta_keys, self.search),
            loglargs=(self.search,), betas=self.betas, a=proposal_scale_factor)

        p0 = self.apply_corrections_to_p0(self.generate_initial_p0())
        self.check_initial_points(p0)

        ninit_steps = len(self.nsteps) - 2
        for stage, stage_steps in enumerate(self.nsteps[:-2]):
            logging.info('Running {}/{} initialisation with {} steps'.format(
                stage, ninit_steps, stage_steps))
            sampler = self.run_sampler(sampler, p0, nburn=stage_steps)
            log_acceptance(sampler)

            if create_plots:
                fig, _ = self.plot_walkers(sampler,
                                           symbols=self.theta_symbols,
                                           **kwargs)
                fig.tight_layout()
                fig.savefig('{}/{}_init_{}_walkers.png'.format(
                    self.outdir, self.label, stage), dpi=400)

            p0 = self.apply_corrections_to_p0(self.get_new_p0(sampler))
            self.check_initial_points(p0)
            sampler.reset()

        nburn = self.nsteps[-2] if len(self.nsteps) > 1 else 0
        nprod = self.nsteps[-1]
        logging.info('Running final burn and prod with {} steps'.format(
            nburn+nprod))
        sampler = self.run_sampler(sampler, p0, nburn=nburn, nprod=nprod)
        log_acceptance(sampler)

        if create_plots:
            fig, _ = self.plot_walkers(sampler, symbols=self.theta_symbols,
                                       nprod=nprod, **kwargs)
            fig.tight_layout()
            fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label),
                        dpi=200)

        # keep only production steps of the first temperature
        production = slice(nburn, None)
        samples = sampler.chain[0, :, production, :].reshape((-1, self.ndim))
        lnprobs = sampler.lnprobability[0, :, production].reshape((-1))
        lnlikes = sampler.lnlikelihood[0, :, production].reshape((-1))
        self.sampler = sampler
        self.samples = samples
        self.lnprobs = lnprobs
        self.lnlikes = lnlikes
        self.save_data(sampler, samples, lnprobs, lnlikes)

    def scale_samples(self, samples, symbols, theta_keys):
        """Apply the transforms in self.rescale_dictionary for plotting.

        For each key in theta_keys that has an entry in
        self.rescale_dictionary, that entry (a dict) may contain:
          'subtractor'                -- column becomes subtractor - column
          'multipler' / 'multiplier'  -- column is multiplied by this value
          'label'                     -- replacement symbol for the column

        The previous version read a nonexistent `self.scale_dictionary` and
        looked up the literal key 'key', so any non-empty rescale_dictionary
        raised an error.

        Parameters
        ----------
        samples: np.ndarray (nsamples, len(theta_keys))
            Modified in place.
        symbols: list of str
            Modified in place.
        theta_keys: list of str
            Column names of samples.

        Returns
        -------
        samples, symbols
        """
        for idx, key in enumerate(theta_keys):
            if key not in self.rescale_dictionary:
                continue
            rescale = self.rescale_dictionary[key]
            s = samples[:, idx]
            if 'subtractor' in rescale:
                s = rescale['subtractor'] - s
            # 'multipler' is the historical spelling; accept both
            multiplier = rescale.get('multipler', rescale.get('multiplier'))
            if multiplier is not None:
                s = s * multiplier
            samples[:, idx] = s

            if 'label' in rescale:
                symbols[idx] = rescale['label']

        return samples, symbols

    def plot_corner(self, figsize=(7, 7), tglitch_ratio=False,
                    add_prior=False, nstds=None, label_offset=0.4,
                    dpi=300, rc_context=None, **kwargs):
        """Save a corner plot of the samples to '{outdir}/{label}_corner.png'.

        With a single searched parameter a histogram is saved instead.

        Parameters
        ----------
        figsize: tuple
            Figure size in inches.
        tglitch_ratio: bool
            If True, a 'tglitch' column is plotted as
            (tglitch - minStartTime) / (maxStartTime - minStartTime).
        add_prior: bool
            If True, overlay the prior on the diagonal panels.
        nstds: int or float, optional
            If given and no explicit `range` is passed, each axis spans the
            median +/- nstds standard deviations of the samples.
        label_offset: float
            Offset of the axis labels on the outer panels.
        dpi: int
            Resolution of the saved figure.
        rc_context: dict, optional
            matplotlib rc settings applied while drawing.
        **kwargs:
            Passed to corner.corner. An explicit `range` takes precedence
            over nstds (previously it was passed twice and raised TypeError).
        """
        if rc_context is None:
            rc_context = {}

        if self.ndim < 2:
            with plt.rc_context(rc_context):
                fig, ax = plt.subplots(figsize=figsize)
                ax.hist(self.samples, bins=50, histtype='stepfilled')
                ax.set_xlabel(self.theta_symbols[0])

            fig.savefig('{}/{}_corner.png'.format(
                self.outdir, self.label), dpi=dpi)
            return

        # remove `range` from kwargs so it is passed to corner exactly once
        _range = kwargs.pop('range', None)
        nstds_is_number = (
            isinstance(nstds, (int, float, np.integer, np.floating))
            and not isinstance(nstds, bool))

        with plt.rc_context(rc_context):
            fig, axes = plt.subplots(self.ndim, self.ndim,
                                     figsize=figsize)

            samples_plt = copy.copy(self.samples)
            theta_symbols_plt = copy.copy(self.theta_symbols)

            samples_plt, theta_symbols_plt = self.scale_samples(
                samples_plt, theta_symbols_plt, self.theta_keys)
            theta_symbols_plt = [s.replace('_{glitch}', r'_\textrm{glitch}')
                                 for s in theta_symbols_plt]

            if tglitch_ratio:
                for j, k in enumerate(self.theta_keys):
                    if k == 'tglitch':
                        s = samples_plt[:, j]
                        samples_plt[:, j] = (
                            s - self.minStartTime)/(
                                self.maxStartTime - self.minStartTime)
                        theta_symbols_plt[j] = r'$R_{\textrm{glitch}}$'

            if _range is None and nstds_is_number:
                _range = []
                for s in samples_plt.T:
                    median = np.median(s)
                    std = np.std(s)
                    _range.append((median - nstds*std, median + nstds*std))

            fig_triangle = corner.corner(samples_plt,
                                         labels=theta_symbols_plt,
                                         fig=fig,
                                         bins=50,
                                         max_n_ticks=4,
                                         plot_contours=True,
                                         plot_datapoints=True,
                                         label_kwargs={'fontsize': 8},
                                         data_kwargs={'alpha': 0.1,
                                                      'ms': 0.5},
                                         range=_range,
                                         **kwargs)

            axes_list = fig_triangle.get_axes()
            axes = np.array(axes_list).reshape(self.ndim, self.ndim)
            plt.draw()
            for ax in axes[:, 0]:
                ax.yaxis.set_label_coords(-label_offset, 0.5)
            for ax in axes[-1, :]:
                ax.xaxis.set_label_coords(0.5, -label_offset)
            for ax in axes_list:
                ax.set_rasterized(True)
                ax.set_rasterization_zorder(-10)
            plt.tight_layout(h_pad=0.0, w_pad=0.0)
            fig.subplots_adjust(hspace=0.05, wspace=0.05)

            if add_prior:
                self.add_prior_to_corner(axes, samples_plt)

            fig_triangle.savefig('{}/{}_corner.png'.format(
                self.outdir, self.label), dpi=dpi)

    def add_prior_to_corner(self, axes, samples):
        """Overlay each searched parameter's prior on the corner diagonal.

        Parameters
        ----------
        axes: 2D array of matplotlib Axes
            The corner-plot axes; the prior is drawn on axes[i][i].
        samples: np.ndarray (nsamples, ndim)
            Samples whose min/max set the x-range of each prior curve.

        The prior is drawn in red on a twin axis with a hidden y-axis, so
        only its shape is shown.
        """
        for i, key in enumerate(self.theta_keys):
            ax = axes[i][i]
            xlim = ax.get_xlim()
            s = samples[:, i]
            prior = self.generic_lnprior(**self.theta_prior[key])
            x = np.linspace(s.min(), s.max(), 100)
            ax2 = ax.twinx()
            ax2.get_yaxis().set_visible(False)
            # generic_lnprior returns the log-density; exponentiate so the
            # curve is the prior density itself.
            # NOTE(review): samples passed from plot_corner may have been
            # rescaled (rescale_dictionary / tglitch_ratio), in which case
            # the prior is evaluated on transformed values -- confirm.
            ax2.plot(x, [np.exp(prior(xi)) for xi in x], '-r')
            ax.set_xlim(xlim)

    def plot_prior_posterior(self, normal_stds=2):
        """ Plot the posterior in the context of the prior

        One panel per searched parameter is saved to
        '{outdir}/{label}_prior_posterior.png'. Each panel shows the prior
        density (red) and a Gaussian KDE of self.samples (black, on a twin
        y-axis).

        Parameters
        ----------
        normal_stds: float
            For 'norm', 'halfnorm' and 'neghalfnorm' priors, how many prior
            scales away from 'loc' the plotted range extends.

        Raises
        ------
        ValueError
            For prior types other than 'unif', 'norm', 'halfnorm' and
            'neghalfnorm'.
        """
        from scipy.stats import gaussian_kde

        fig, axes = plt.subplots(nrows=self.ndim, figsize=(8, 4*self.ndim))
        # with a single row plt.subplots returns one Axes, not an array
        axes = np.atleast_1d(axes)
        N = 1000

        for i, (ax, key) in enumerate(zip(axes, self.theta_keys)):
            prior_dict = self.theta_prior[key]
            prior_func = self.generic_lnprior(**prior_dict)
            # generic_lnprior returns the log-density; exponentiate so the
            # red curve is the prior density rather than its logarithm
            if prior_dict['type'] == 'unif':
                x = np.linspace(prior_dict['lower'], prior_dict['upper'], N)
                prior = np.exp(prior_func(x))
                # drop to zero at the edges so the box is drawn closed
                prior[0] = 0
                prior[-1] = 0
            elif prior_dict['type'] == 'norm':
                lower = prior_dict['loc'] - normal_stds * prior_dict['scale']
                upper = prior_dict['loc'] + normal_stds * prior_dict['scale']
                x = np.linspace(lower, upper, N)
                prior = np.exp(prior_func(x))
            elif prior_dict['type'] == 'halfnorm':
                lower = prior_dict['loc']
                upper = prior_dict['loc'] + normal_stds * prior_dict['scale']
                x = np.linspace(lower, upper, N)
                prior = np.exp([prior_func(xi) for xi in x])
            elif prior_dict['type'] == 'neghalfnorm':
                upper = prior_dict['loc']
                lower = prior_dict['loc'] - normal_stds * prior_dict['scale']
                x = np.linspace(lower, upper, N)
                prior = np.exp([prior_func(xi) for xi in x])
            else:
                raise ValueError('Not implemented for prior type {}'.format(
                    prior_dict['type']))
            priorln = ax.plot(x, prior, 'r', label='prior')
            ax.set_xlabel(self.theta_symbols[i])

            s = self.samples[:, i]
            while len(s) > 10**4:
                # random downsample to avoid slow calculation of kde
                s = np.random.choice(s, size=int(len(s)/2.))
            kde = gaussian_kde(s)
            ax2 = ax.twinx()
            postln = ax2.plot(x, kde.pdf(x), 'k', label='posterior')
            ax2.set_yticklabels([])
            ax.set_yticklabels([])

        lns = priorln + postln
        labs = [l.get_label() for l in lns]
        axes[0].legend(lns, labs, loc=1, framealpha=0.8)

        fig.savefig('{}/{}_prior_posterior.png'.format(
            self.outdir, self.label))

    def plot_cumulative_max(self, **kwargs):
        """Plot the cumulative twoF at the maximum-twoF point of the search.

        The point comes from get_max_twoF(); parameters that were held fixed
        take their values from theta_prior. The search object is created on
        demand. **kwargs are passed to self.search.plot_twoF_cumulative.
        """
        d, _ = self.get_max_twoF()
        for key, val in self.theta_prior.items():
            if key not in d:
                d[key] = val

        if not hasattr(self, 'search'):
            self.initiate_search_object()

        if not self.binary:
            self.search.plot_twoF_cumulative(
                self.label, self.outdir, F0=d['F0'], F1=d['F1'], F2=d['F2'],
                Alpha=d['Alpha'], Delta=d['Delta'],
                tstart=self.minStartTime, tend=self.maxStartTime,
                **kwargs)
        else:
            # tp previously received d['argp'] by mistake
            self.search.plot_twoF_cumulative(
                self.label, self.outdir, F0=d['F0'], F1=d['F1'], F2=d['F2'],
                Alpha=d['Alpha'], Delta=d['Delta'], asini=d['asini'],
                period=d['period'], ecc=d['ecc'], argp=d['argp'], tp=d['tp'],
                tstart=self.minStartTime, tend=self.maxStartTime, **kwargs)

    def generic_lnprior(self, **kwargs):
        """ Return a function evaluating the log of a prior pdf

        Parameters
        ----------
        kwargs: dict
            'type' selects the distribution; the remaining keys are its shape
            parameters:
              'unif'        -- lower, upper (open interval; -inf outside)
              'halfnorm'    -- loc, scale (support x >= loc)
              'neghalfnorm' -- loc, scale (halfnorm evaluated at -x)
              'norm'        -- loc, scale

        Returns
        -------
        callable
            x -> log pdf. Accepts a scalar or a numpy array.

        Raises
        ------
        ValueError
            If 'type' is not one of the above.
        """

        def logunif(x, a, b):
            above = x < b
            below = x > a
            if type(above) is not np.ndarray:
                if above and below:
                    return -np.log(b-a)
                else:
                    return -np.inf
            else:
                p = np.full(np.shape(x), -np.inf)
                p[above & below] = -np.log(b-a)
                return p

        def halfnorm(x, loc, scale):
            logp = -0.5*((x-loc)**2/scale**2 + np.log(0.5*np.pi*scale**2))
            if np.ndim(x) == 0:
                return -np.inf if x < loc else logp
            # array input: mask points below loc
            return np.where(np.asarray(x) < loc, -np.inf, logp)

        if kwargs['type'] == 'unif':
            return lambda x: logunif(x, kwargs['lower'], kwargs['upper'])
        elif kwargs['type'] == 'halfnorm':
            return lambda x: halfnorm(x, kwargs['loc'], kwargs['scale'])
        elif kwargs['type'] == 'neghalfnorm':
            return lambda x: halfnorm(-x, kwargs['loc'], kwargs['scale'])
        elif kwargs['type'] == 'norm':
            return lambda x: -0.5*((x - kwargs['loc'])**2/kwargs['scale']**2
                                   + np.log(2*np.pi*kwargs['scale']**2))
        else:
            logging.info("kwargs: {}".format(kwargs))
            raise ValueError("Unrecognised distribution type {}".format(
                kwargs['type']))

    def generate_rv(self, **kwargs):
        """Draw one random value from the distribution described by kwargs.

        'type' selects the distribution:
          'unif'        -- uniform on [lower, upper)
          'norm'        -- normal with loc, scale
          'halfnorm'    -- loc + |N(0, scale)|, i.e. support x >= loc
          'neghalfnorm' -- -(loc + |N(0, scale)|), i.e. support x <= -loc
          'lognorm'     -- numpy lognormal with mean=loc, sigma=scale

        The halfnorm draws match the densities defined in generic_lnprior.
        Previously they were |N(loc, scale)|, which differs whenever loc != 0
        and can fall outside the prior support. For loc == 0 the result is
        unchanged.
        NOTE(review): plot_prior_posterior plots 'neghalfnorm' up to +loc
        rather than -loc; the conventions only agree for loc == 0.

        Raises
        ------
        ValueError
            If the distribution type is unknown.
        """
        dist_type = kwargs.pop('type')
        if dist_type == "unif":
            return np.random.uniform(low=kwargs['lower'], high=kwargs['upper'])
        if dist_type == "norm":
            return np.random.normal(loc=kwargs['loc'], scale=kwargs['scale'])
        if dist_type == "halfnorm":
            return kwargs['loc'] + np.abs(
                np.random.normal(loc=0, scale=kwargs['scale']))
        if dist_type == "neghalfnorm":
            return -1 * (kwargs['loc'] + np.abs(
                np.random.normal(loc=0, scale=kwargs['scale'])))
        if dist_type == "lognorm":
            return np.random.lognormal(
                mean=kwargs['loc'], sigma=kwargs['scale'])
        raise ValueError("dist_type {} unknown".format(dist_type))

    def plot_walkers(self, sampler, symbols=None, alpha=0.4, color="k", temp=0,
                     lw=0.1, nprod=0, add_det_stat_burnin=False,
                     fig=None, axes=None, xoffset=0, plot_det_stat=True,
                     context='classic', subtractions=None, labelpad=0.05):
        """ Plot all the chains from a sampler

        Parameters
        ----------
        sampler: sampler object
            Object with a `chain` attribute of shape
            (nwalkers, nsteps, ndim) or (ntemps, nwalkers, nsteps, ndim);
            when `plot_det_stat` is True it must also have a 3D
            `lnlikelihood` indexed as [temp, walker, step].
        symbols: list of str, optional
            y-axis labels, one per dimension.
        alpha, lw: float
            Transparency and line width of the walker traces.
        color: str
            Colour of the production traces and production histogram; the
            burn-in portion is always drawn in red.
        temp: int
            Temperature index to plot when `chain` is 4D.
        nprod: int
            Number of production steps at the end of the chain; earlier steps
            are drawn as burn-in.
        add_det_stat_burnin: bool
            If True, also histogram the burn-in detection statistic.
        fig, axes: optional
            Figure and/or axes to draw on. Missing axes are created (one row
            per dimension plus one for the detection statistic).
        xoffset: int
            Offset added to the step index on the x-axis.
        plot_det_stat: bool
            If True, add a histogram of the detection statistic as the last
            axis.
        context: str
            Matplotlib style context used while plotting.
        subtractions: list of float, optional
            Per-dimension values subtracted from the traces before plotting.
        labelpad: float
            Padding for the y-axis labels.

        Returns
        -------
        fig, axes
            The figure (None if axes were passed without a figure and no
            axis had to be added) and the list of axes drawn on.

        Raises
        ------
        ValueError
            If `temp` is out of range, the chain shape is not 3D/4D, or
            `subtractions` has the wrong length.
        """
        if axes is not None:
            # Work with a flat list so that a det. stat. axis can be appended
            # below even when the caller passed a (possibly 2D) numpy array
            axes = list(np.ravel(axes))

        shape = sampler.chain.shape
        if len(shape) == 3:
            nwalkers, nsteps, ndim = shape
            chain = sampler.chain[:, :, :]
        elif len(shape) == 4:
            ntemps, nwalkers, nsteps, ndim = shape
            if temp < ntemps:
                logging.info("Plotting temperature {} chains".format(temp))
            else:
                raise ValueError(("Requested temperature {} outside of "
                                  "available range").format(temp))
            chain = sampler.chain[temp, :, :, :]
        else:
            raise ValueError(
                "Unexpected sampler.chain shape {}".format(shape))

        if subtractions is None:
            subtractions = [0 for i in range(ndim)]
        elif len(subtractions) != ndim:
            raise ValueError('subtractions must be of length ndim')

        with plt.style.context((context)):
            plt.rcParams['text.usetex'] = True
            if axes is None:
                if fig is None:
                    fig = plt.figure(figsize=(4, 3.0*ndim))
                # Rows 1..ndim hold the traces; row ndim+1 is left for the
                # det. stat. histogram
                axes = [fig.add_subplot(ndim+1, 1, i)
                        for i in range(1, ndim+1)]

            idxs = np.arange(chain.shape[1])
            burnin_idx = chain.shape[1] - nprod
            # Burn-in (red) is drawn up to convergence_idx when a convergence
            # test has set it, otherwise up to burnin_idx
            convergence_idx = getattr(self, 'convergence_idx', burnin_idx)

            if ndim > 1:
                for i in range(ndim):
                    axes[i].ticklabel_format(useOffset=False, axis='y')
                    cs = chain[:, :, i].T
                    if burnin_idx > 0:
                        axes[i].plot(xoffset+idxs[:convergence_idx],
                                     cs[:convergence_idx]-subtractions[i],
                                     color="r", alpha=alpha, lw=lw)
                    axes[i].plot(xoffset+idxs[burnin_idx:],
                                 cs[burnin_idx:]-subtractions[i],
                                 color=color, alpha=alpha, lw=lw)
                    if symbols:
                        if subtractions[i] == 0:
                            axes[i].set_ylabel(symbols[i], labelpad=labelpad)
                        else:
                            axes[i].set_ylabel(
                                symbols[i]+'$-$'+symbols[i]+'$_0$',
                                labelpad=labelpad)

                    if hasattr(self, 'convergence_diagnostic'):
                        ax = axes[i].twinx()
                        c_x = np.array(self.convergence_diagnosticx)
                        c_y = np.array(self.convergence_diagnostic)
                        break_idx = np.argmin(np.abs(c_x - burnin_idx))
                        ax.plot(c_x[:break_idx], c_y[:break_idx, i], '-b')
                        ax.plot(c_x[break_idx:], c_y[break_idx:, i], '-b')
                        ax.set_ylabel('PSRF')
                        ax.ticklabel_format(useOffset=False)
                        ax.set_ylim(1, self.convergence_plot_upper_lim)
            else:
                axes[0].ticklabel_format(useOffset=False, axis='y')
                # Single search dimension: index it with 0 (`temp` is a
                # temperature index, not a parameter index)
                cs = chain[:, :, 0].T
                if burnin_idx > 0:
                    axes[0].plot(xoffset+idxs[:burnin_idx],
                                 cs[:burnin_idx]-subtractions[0],
                                 color="r", alpha=alpha, lw=lw)
                axes[0].plot(xoffset+idxs[burnin_idx:],
                             cs[burnin_idx:]-subtractions[0],
                             color=color, alpha=alpha, lw=lw)
                if symbols:
                    axes[0].set_ylabel(symbols[0], labelpad=labelpad)

            if plot_det_stat:
                if len(axes) == ndim:
                    if fig is None:
                        # Axes were passed without a figure: add the
                        # histogram to the figure that owns them
                        fig = axes[0].get_figure()
                    axes.append(fig.add_subplot(ndim+1, 1, ndim+1))

                lnl = sampler.lnlikelihood[temp, :, :]
                if burnin_idx and add_det_stat_burnin:
                    burn_in_vals = lnl[:, :burnin_idx].flatten()
                    try:
                        axes[-1].hist(burn_in_vals[~np.isnan(burn_in_vals)],
                                      bins=50, histtype='step', color='r')
                    except ValueError:
                        logging.info('Det. Stat. hist failed, most likely all '
                                     'values where the same')
                else:
                    burn_in_vals = []
                prod_vals = lnl[:, burnin_idx:].flatten()
                try:
                    axes[-1].hist(prod_vals[~np.isnan(prod_vals)], bins=50,
                                  histtype='step', color=color)
                except ValueError:
                    logging.info('Det. Stat. hist failed, most likely all '
                                 'values where the same')
                if self.BSGL:
                    axes[-1].set_xlabel(r'$\mathcal{B}_\mathrm{S/GL}$')
                else:
                    axes[-1].set_xlabel(r'$\widetilde{2\mathcal{F}}$')
                axes[-1].set_ylabel(r'$\textrm{Counts}$')
                combined_vals = np.append(burn_in_vals, prod_vals)
                if len(combined_vals) > 0:
                    minv = np.min(combined_vals)
                    maxv = np.max(combined_vals)
                    Range = abs(maxv-minv)
                    axes[-1].set_xlim(minv-0.1*Range, maxv+0.1*Range)

                xfmt = matplotlib.ticker.ScalarFormatter()
                xfmt.set_powerlimits((-4, 4))
                axes[-1].xaxis.set_major_formatter(xfmt)

            # Label the last trace panel (not the histogram); axes[ndim-1]
            # is correct whether or not the det. stat. axis was added
            axes[ndim-1].set_xlabel(r'$\textrm{Number of steps}$',
                                    labelpad=0.2)

        return fig, axes

Gregory Ashton's avatar
Gregory Ashton committed
792
793
794
795
796
    def apply_corrections_to_p0(self, p0):
        """ Hook for adjusting initial walker positions

        The base implementation applies no correction and hands back the
        very same object it was given.
        """
        corrected_p0 = p0
        return corrected_p0

    def generate_scattered_p0(self, p):
        """ Generate a set of p0s scattered about p

        Each walker at each temperature is drawn as
        ``p + scatter_val * p * N(0, 1)`` (independent per component), so the
        scatter is relative to p: components of p equal to zero are not
        scattered.

        Parameters
        ----------
        p: array_like of length ndim
            Point to scatter about.

        Returns
        -------
        p0: list (ntemps) of lists (nwalkers) of numpy arrays (ndim)
        """
        # Accept lists/tuples too; arithmetic below needs array semantics
        p = np.asarray(p)
        # range (not the Python-2-only xrange) so this runs on 2 and 3
        p0 = [[p + self.scatter_val * p * np.random.randn(self.ndim)
               for i in range(self.nwalkers)]
              for j in range(self.ntemps)]
        return p0

Gregory Ashton's avatar
Gregory Ashton committed
803
    def generate_initial_p0(self):
        """ Generate a set of init vals for the walkers

        How the values are generated depends on `self.theta_initial`:

        - a dict (including OrderedDict): one distribution dict (as accepted
          by `generate_rv`) per key of `self.theta_keys`, drawn
          independently for each walker and temperature;
        - a non-empty list of dicts: one distribution dict per search
          dimension, in order;
        - None: walkers are drawn from `self.theta_prior`;
        - any other sequence of numbers of length `self.ndim` (list, tuple or
          array): walkers are scattered about that point with
          `generate_scattered_p0`.

        Returns
        -------
        p0: nested list indexed [temperature][walker][dimension], or the
            return value of `generate_scattered_p0`.

        Raises
        ------
        ValueError
            If `theta_initial` is a dict while `nglitch > 1`, or if it matches
            none of the forms above.
        """
        theta_initial = self.theta_initial
        if isinstance(theta_initial, dict):
            logging.info('Generate initial values from initial dictionary')
            if hasattr(self, 'nglitch') and self.nglitch > 1:
                raise ValueError('Initial dict not implemented for nglitch>1')
            p0 = [[[self.generate_rv(**theta_initial[key])
                    for key in self.theta_keys]
                   for i in range(self.nwalkers)]
                  for j in range(self.ntemps)]
        elif (isinstance(theta_initial, list) and len(theta_initial) > 0
              and all(isinstance(val, dict) for val in theta_initial)):
            logging.info('Generate initial values from list of theta_initial')
            p0 = [[[self.generate_rv(**val)
                    for val in theta_initial]
                   for i in range(self.nwalkers)]
                  for j in range(self.ntemps)]
        elif theta_initial is None:
            logging.info('Generate initial values from prior dictionary')
            p0 = [[[self.generate_rv(**self.theta_prior[key])
                    for key in self.theta_keys]
                   for i in range(self.nwalkers)]
                  for j in range(self.ntemps)]
        elif len(theta_initial) == self.ndim:
            # A single point (list, tuple or array of numbers); convert so
            # the relative scatter arithmetic works for any sequence type
            p0 = self.generate_scattered_p0(
                np.asarray(theta_initial, dtype=float))
        else:
            raise ValueError('theta_initial not understood')

        return p0

833
    def get_new_p0(self, sampler):
        """ Returns new initial positions for walkers after the burn0 stage

        Finds the sample with the highest finite lnprobability at temperature
        index 0 (presumably the untempered chain — confirm against the
        sampler's beta ordering) and scatters all walkers about it with scale
        `scatter_val` via `generate_scattered_p0`. The twoF at that point is
        evaluated (with BSGL temporarily disabled) for logging only.

        Parameters
        ----------
        sampler: sampler object
            Must provide 4D `chain` and 3D `lnlikelihood` / `lnprobability`
            arrays indexed [temperature, walker, step(, dim)].

        Returns
        -------
        p0: the return value of `generate_scattered_p0`

        Raises
        ------
        ValueError
            If no lnprobability at temperature index 0 is finite.
        """
        temp_idx = 0
        pF = sampler.chain[temp_idx, :, :, :]
        lnl = sampler.lnlikelihood[temp_idx, :, :]
        lnp = sampler.lnprobability[temp_idx, :, :]

        # General warnings about the state of lnp (total count, not shape)
        if np.any(np.isnan(lnp)):
            logging.warning(
                "Of {} lnprobs {} are nan".format(
                    np.size(lnp), np.sum(np.isnan(lnp))))
        if np.any(np.isposinf(lnp)):
            logging.warning(
                "Of {} lnprobs {} are +np.inf".format(
                    np.size(lnp), np.sum(np.isposinf(lnp))))
        if np.any(np.isneginf(lnp)):
            logging.warning(
                "Of {} lnprobs {} are -np.inf".format(
                    np.size(lnp), np.sum(np.isneginf(lnp))))

        # Ignore +/-inf as well as nan when locating the best sample
        lnp_finite = copy.copy(lnp)
        lnp_finite[np.isinf(lnp)] = np.nan
        if np.all(np.isnan(lnp_finite)):
            raise ValueError(
                'No finite lnprob values available to generate new p0')
        idx = np.unravel_index(np.nanargmax(lnp_finite), lnp_finite.shape)
        p = pF[idx]
        p0 = self.generate_scattered_p0(p)

        # Evaluate plain twoF for the log message; always restore the BSGL
        # flag, even if the likelihood evaluation raises
        self.search.BSGL = False
        try:
            twoF = self.logl(p, self.search)
        finally:
            self.search.BSGL = self.BSGL

        logging.info(('Gen. new p0 from pos {} which had det. stat.={:2.1f},'
                      ' twoF={:2.1f} and lnp={:2.1f}')
                     .format(idx[1], lnl[idx], twoF, lnp_finite[idx]))

        return p0

    def get_save_data_dictionary(self):
        """ Collect the run settings stored alongside the results

        Returns
        -------
        dict mapping each setting name (nsteps, nwalkers, ntemps, theta_keys,
        theta_prior, scatter_val, log10temperature_min, BSGL) to the value of
        the attribute of the same name on this object.
        """
        setting_names = ('nsteps', 'nwalkers', 'ntemps', 'theta_keys',
                         'theta_prior', 'scatter_val', 'log10temperature_min',
                         'BSGL')
        return {name: getattr(self, name) for name in setting_names}

    def save_data(self, sampler, samples, lnprobs, lnlikes):
        """ Pickle the run settings together with the sampler output

        The settings from `get_save_data_dictionary` are extended with the
        'sampler', 'samples', 'lnprobs' and 'lnlikes' entries and written to
        `self.pickle_path`. An existing file at that path is first renamed to
        `<pickle_path>.old`.
        """
        payload = self.get_save_data_dictionary()
        payload.update(sampler=sampler, samples=samples, lnprobs=lnprobs,
                       lnlikes=lnlikes)

        target = self.pickle_path
        if os.path.isfile(target):
            logging.info('Saving backup of {} as {}.old'.format(
                target, target))
            os.rename(target, target + ".old")
        with open(target, "wb") as File:
            pickle.dump(payload, File)

    def get_saved_data(self):
        """ Load and return the dictionary pickled at `self.pickle_path`

        The file is opened in binary mode to match the "wb" mode used when
        it is written; text mode breaks unpickling on Python 3 and can
        corrupt the data on Windows.

        NOTE: unpickling can execute arbitrary code, so only load files this
        code produced.
        """
        with open(self.pickle_path, "rb") as File:
            d = pickle.load(File)
        return d

    def check_old_data_is_okay_to_use(self):
        """ Decide whether previously pickled results can be reused

        Returns True when `args.use_old_data` is set (presumably a
        command-line flag from `core.args`), or when the pickle at
        `self.pickle_path` exists, is not older than any matching SFT file
        (checked only when `sftfilepath` is set), and was produced with the
        same settings as `get_save_data_dictionary` currently returns.
        Otherwise the reason is logged and False is returned.

        Raises
        ------
        ValueError
            If a current setting is missing from the saved dictionary.
        """
        if args.use_old_data:
            logging.info("Forcing use of old data")
            return True

        if os.path.isfile(self.pickle_path) is False:
            logging.info('No pickled data found')
            return False

        if self.sftfilepath is not None:
            sft_mtimes = [os.path.getmtime(f) for f in
                          self.get_list_of_matching_sfts()]
            if len(sft_mtimes) == 0:
                logging.warning(
                    'No SFT files match {}; skipping age check'.format(
                        self.sftfilepath))
            # The pickle is stale if ANY SFT was modified after it was
            # written, so compare against the newest SFT, not the oldest
            elif os.path.getmtime(self.pickle_path) < max(sft_mtimes):
                logging.info('Pickled data outdates sft files')
                return False

        old_d = self.get_saved_data().copy()
        new_d = self.get_save_data_dictionary().copy()

        # Only the settings are compared, not the stored results
        old_d.pop('samples')
        old_d.pop('sampler')
        old_d.pop('lnprobs')
        old_d.pop('lnlikes')

        mod_keys = []
        for key in new_d.keys():
            if key in old_d:
                if new_d[key] != old_d[key]:
                    mod_keys.append((key, old_d[key], new_d[key]))
            else:
                raise ValueError('Keys {} not in old dictionary'.format(key))

        if len(mod_keys) == 0:
            return True
        else:
            logging.warning("Saved data differs from requested")
            logging.info("Differences found in following keys:")
            for key in mod_keys:
                if len(key) == 3:
                    if np.isscalar(key[1]) or key[0] == 'nsteps':
                        logging.info("    {} : {} -> {}".format(*key))
                    else:
                        logging.info("    " + key[0])
                else:
                    logging.info(key)
            return False

    def get_max_twoF(self, threshold=0.05):
        """ Returns the max likelihood sample and the corresponding 2F value

        Parameters
        ----------
        threshold: float
            Currently unused; kept for backward compatibility.

        Returns
        -------
        d: OrderedDict
            Parameter values of the sample with the largest finite value of
            `self.lnlikes`, keyed by `self.theta_keys`. Keys appearing more
            than once are relabelled `key_0`, `key_1`, ...
        maxtwoF: float
            twoF at that sample: recomputed with BSGL disabled when
            `self.BSGL` is set, otherwise the maximum lnlike itself.

        Raises
        ------
        ValueError
            If `self.lnlikes` has no finite values.
        """
        if any(np.isposinf(self.lnlikes)):
            logging.info('twoF values contain positive infinite values')
        if any(np.isneginf(self.lnlikes)):
            logging.info('twoF values contain negative infinite values')
        if any(np.isnan(self.lnlikes)):
            logging.info('twoF values contain nan')
        # Mask non-finite values with nan rather than dropping them, so that
        # the argmax indexes self.lnlikes and self.samples directly
        finite_lnlikes = np.where(np.isfinite(self.lnlikes), self.lnlikes,
                                  np.nan)
        if np.all(np.isnan(finite_lnlikes)):
            raise ValueError('No finite twoF values found')
        jmax = np.nanargmax(finite_lnlikes)
        maxlogl = self.lnlikes[jmax]
        d = OrderedDict()

        if self.BSGL:
            if hasattr(self, 'search') is False:
                self.initiate_search_object()
            p = self.samples[jmax]
            # Always restore the BSGL flag, even if logl raises
            self.search.BSGL = False
            try:
                maxtwoF = self.logl(p, self.search)
            finally:
                self.search.BSGL = self.BSGL
        else:
            maxtwoF = maxlogl

        repeats = []
        for i, k in enumerate(self.theta_keys):
            if k in d and k not in repeats:
                d[k+'_0'] = d[k]  # relabel the old key
                d.pop(k)
                repeats.append(k)
            if k in repeats:
                # Find the next free suffix: k_0, k_1, ...
                k = k + '_0'
                count = 1
                while k in d:
                    k = k.replace('_{}'.format(count-1), '_{}'.format(count))
                    count += 1
            d[k] = self.samples[jmax][i]
        return d, maxtwoF

    def get_median_stds(self):
        """ Summarise each parameter's production samples

        Returns
        -------
        OrderedDict
            For every key of `self.theta_keys`, `key` -> median and
            `key_std` -> standard deviation of the matching column of
            `self.samples`. A key seen more than once is relabelled `key_0`,
            `key_1`, ... (with matching `_std` entries).
        """
        summary = OrderedDict()
        duplicated = []
        for name, column in zip(self.theta_keys, self.samples.T):
            if name in summary and name not in duplicated:
                # Second sighting: move the first entry to the '_0' slot
                summary[name + '_0'] = summary.pop(name)
                summary[name + '_0_std'] = summary.pop(name + '_std')
                duplicated.append(name)
            if name in duplicated:
                candidate = name + '_0'
                suffix = 1
                while candidate in summary:
                    candidate = candidate.replace(
                        '_{}'.format(suffix - 1), '_{}'.format(suffix))
                    suffix += 1
                name = candidate

            summary[name] = np.median(column)
            summary[name + '_std'] = np.std(column)
        return summary

1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
    def check_if_samples_are_railing(self, threshold=0.01):
        """ Warn when samples pile up against the edges of a uniform prior

        A sample counts as being on an edge when its distance to that edge
        is less than `threshold` times the prior width. Only parameters with
        a 'unif' prior are checked.

        Returns
        -------
        bool: True if any parameter has samples on a prior edge.
        """
        railing = False
        for key, column in zip(self.theta_keys, self.samples.T):
            prior = self.theta_prior[key]
            if prior['type'] != 'unif':
                continue
            width = prior['upper'] - prior['lower']
            near_edges = [
                (edge, np.abs(column - prior[edge])/width < threshold)
                for edge in ('lower', 'upper')]
            hits = [(edge, near) for edge, near in near_edges
                    if np.any(near)]
            if not hits:
                continue
            edges = [edge for edge, _ in hits]
            percents = [str(100*float(np.sum(near))/len(near))
                        for _, near in hits]
            logging.warning(
                '{}% of the {} posterior is railing on the {} edges'
                .format('% & '.join(percents), key, ' & '.join(edges)))
            railing = True
        return railing

1035
1036
1037
1038
    def write_par(self, method='med'):
        """ Writes a .par of the best-fit params with an estimated std

        The file `{outdir}/{label}.par` gets the max twoF, `tref`,
        `theta0_index` (when the object has `theta0_idx`) and one
        `key = value` line per entry of the chosen parameter dictionary.

        Parameters
        ----------
        method: str
            'med' writes the medians and stds from `get_median_stds`;
            'twoFmax' writes the max-twoF sample from `get_max_twoF`.

        Raises
        ------
        ValueError
            If `method` is neither 'med' nor 'twoFmax'.
        """
        if method not in ('med', 'twoFmax'):
            raise ValueError(
                "method must be 'med' or 'twoFmax', got {}".format(method))
        logging.info('Writing {}/{}.par using the {} method'.format(
            self.outdir, self.label, method))

        median_std_d = self.get_median_stds()
        max_twoF_d, max_twoF = self.get_max_twoF()
        params = median_std_d if method == 'med' else max_twoF_d

        logging.info('Writing par file with max twoF = {}'.format(max_twoF))
        filename = '{}/{}.par'.format(self.outdir, self.label)
        with open(filename, 'w+') as f:
            f.write('MaxtwoF = {}\n'.format(max_twoF))
            f.write('tref = {}\n'.format(self.tref))
            # Check the attribute that is actually written
            if hasattr(self, 'theta0_idx'):
                f.write('theta0_index = {}\n'.format(self.theta0_idx))
            for key, val in params.items():
                f.write('{} = {:1.16e}\n'.format(key, val))

Gregory Ashton's avatar
Gregory Ashton committed
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
    def write_prior_table(self):
        """ Write a LaTeX tabular of the priors to {outdir}/{label}_prior.tex

        The table has three columns (symbol, prior, unit) and one row per
        entry of `self.theta_prior` whose value is a dict (fixed parameters
        are skipped). Supported prior types are 'unif', 'norm', 'halfnorm'
        and 'neghalfnorm'; any other type is skipped with a warning instead
        of reusing the previous row's values.
        """
        # type -> (first shape key, second shape key, row template)
        templates = {
            'unif': ('lower', 'upper',
                     r"{} & $\mathrm{{Unif}}$({}, {}) & {}"),
            'norm': ('loc', 'scale',
                     r"{} & $\mathcal{{N}}$({}, {}) & {}"),
            'halfnorm': ('loc', 'scale',
                         r"{} & $|\mathcal{{N}}$({}, {})| & {}"),
            'neghalfnorm': ('loc', 'scale',
                            r"{} & $-|\mathcal{{N}}$({}, {})| & {}"),
        }
        with open('{}/{}_prior.tex'.format(self.outdir, self.label), 'w') as f:
            # Header must have exactly three cells to match {c l c}
            f.write(r"\begin{tabular}{c l c} \hline" + "\n"
                    + r"Parameter & & \\ \hhline{===}")

            for key, prior in self.theta_prior.items():
                if not isinstance(prior, dict):
                    continue  # fixed value, not searched over
                Type = prior['type']
                if Type not in templates:
                    logging.warning(
                        'Prior type {} for {} not supported in prior table; '
                        'skipping'.format(Type, key))
                    continue
                a_key, b_key, line = templates[Type]
                u = self.unit_dictionary[key]
                s = self.symbol_dictionary[key]
                a = helper_functions.texify_float(prior[a_key])
                b = helper_functions.texify_float(prior[b_key])
                f.write("\n")
                # Exactly one row terminator per row
                f.write(" " + line.format(s, a, b, u) + r" \\")
            f.write("\n" + r"\end{tabular}" + "\n")

1086
    def print_summary(self):
        """ Log the max-twoF sample and the median +/- std of each parameter

        Both sections are emitted through `logging.info`, with keys in
        sorted order.
        """
        max_twoFd, max_twoF = self.get_max_twoF()
        median_std_d = self.get_median_stds()
        logging.info('Summary:')
        if hasattr(self, 'theta0_idx'):
            logging.info('theta0 index: {}'.format(self.theta0_idx))
        logging.info('Max twoF: {} with parameters:'.format(max_twoF))
        # sorted() works on dict views in both Python 2 and 3
        for k in sorted(max_twoFd.keys()):
            logging.info('  {:10s} = {:1.9e}'.format(k, max_twoFd[k]))
        logging.info('Median +/- std for production values')
        for k in sorted(median_std_d.keys()):
            if 'std' not in k:
                logging.info('  {:10s} = {:1.9e} +/- {:1.9e}'.format(
                    k, median_std_d[k], median_std_d[k+'_std']))
        logging.info('\n')
1101

1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
    def CF_twoFmax(self, theta, twoFmax, ntrials):
        """ Integrand of the characteristic function of the largest 2F

        With F = twoFmax/2 this returns
        exp(i*theta*twoFmax) * ntrials/2 * F*exp(-F)
        * (1 - (1+F)*exp(-F))**(ntrials-1),
        i.e. exp(i*theta*x) times the density of the maximum of `ntrials`
        independent chi^2_4-distributed values x = 2F. Integrating over
        twoFmax gives the characteristic function at `theta`.
        Works elementwise on numpy arrays.
        """
        Fmax = twoFmax/2.0
        phase = np.exp(1j*theta*twoFmax)
        decay = np.exp(-Fmax)
        cdf = 1 - (1 + Fmax)*decay
        return phase*ntrials/2.0 * Fmax*decay*cdf**(ntrials - 1)

    def pdf_twoFhat(self, twoFhat, nglitch, ntrials, twoFmax=100, dtwoF=0.1):
        """ Numerically evaluate the pdf of twoFhat at the given values

        twoFhat is modelled as the sum of `nglitch+1` independent segment
        maxima. Each segment's characteristic function is obtained by
        integrating `CF_twoFmax` over a twoF grid [0, twoFmax) with spacing
        dtwoF. The segment characteristic functions are multiplied together
        and the product is inverted by numerical Fourier integration.

        Parameters
        ----------
        twoFhat: array_like
            Values at which to evaluate the pdf.
        nglitch: int
            Number of glitches (number of segments minus one).
        ntrials: scalar or sequence of length nglitch+1
            Number of trials per segment; a scalar is used for every segment.

        Returns
        -------
        numpy array of the real part of the pdf at each twoFhat value.
        """
        if np.ndim(ntrials) == 0:
            ntrials = np.zeros(nglitch+1) + ntrials
        twoF_grid = np.arange(0, twoFmax, dtwoF)
        theta_grid = np.arange(-1/dtwoF, 1./dtwoF, 1./twoFmax)

        segment_cfs = []
        for n_seg in ntrials:
            segment_cfs.append(
                [np.trapz(self.CF_twoFmax(t, twoF_grid, n_seg), twoF_grid)
                 for t in theta_grid])
        total_cf = np.prod(np.array(segment_cfs), axis=0)

        inverted = [np.trapz(np.exp(-1j*theta_grid*value) * total_cf,
                             theta_grid)
                    for value in twoFhat]
        pdf = (1/(2*np.pi)) * np.array(inverted)
        return pdf.real

    def p_val_twoFhat(self, twoFhat, ntrials, twoFhatmax=500, Npoints=1000):
        """ Calculate the p-value of an observed twoFhat

        Integrates `pdf_twoFhat` (with `self.nglitch` glitches) by the
        trapezoidal rule on `Npoints` evenly spaced points from `twoFhat` up
        to `twoFhatmax`.

        Parameters
        ----------
        twoFhat: float
            The observed twoFhat value.
        ntrials: int, or array of length nglitch+1
            The number of trials in each segment.
        twoFhatmax: float
            Upper limit of the integration.
        Npoints: int
            Number of integration points.

        Returns
        -------
        float: the integrated tail probability.
        """
        grid = np.linspace(twoFhat, twoFhatmax, Npoints)
        tail_pdf = self.pdf_twoFhat(grid, self.nglitch, ntrials)
        return np.trapz(tail_pdf, grid)

    def get_p_value(self, delta_F0, time_trials=0):
        """ Get the p-value for the maximum twoFhat value

        The observing span [minStartTime, maxStartTime] is split at the glitch
        times of the max-twoF sample. Each segment of duration dT is assigned
        `time_trials + delta_F0 * dT` trials. The p-value from
        `p_val_twoFhat` is printed and returned.
        """
        params, max_twoF = self.get_max_twoF()
        if self.nglitch == 1:
            glitch_keys = ['tglitch']
        else:
            glitch_keys = ['tglitch_{}'.format(i)
                           for i in range(self.nglitch)]
        boundaries = ([self.minStartTime] + [params[k] for k in glitch_keys]
                      + [self.maxStartTime])
        segment_lengths = np.diff(boundaries)
        ntrials = [time_trials + delta_F0 * length
                   for length in segment_lengths]

        p_val = self.p_val_twoFhat(max_twoF, ntrials)
        print('p-value = {}'.format(p_val))
        return p_val

1151
    def get_evidence(self):
1152
1153
1154
1155
1156
1157
        fburnin = float(self.nsteps[-2])/np.sum(self.nsteps[-2:])
        lnev, lnev_err = self.sampler.thermodynamic_integration_log_evidence(
            fburnin=fburnin)

        log10evidence = lnev/np.log(10)
        log10evidence_err = lnev_err/np.log(10)
1158
1159
1160
1161
        return log10evidence, log10evi