mcmc_based_searches.py 87.2 KB
Newer Older
1001
1002
1003
1004
1005
1006
1007
        if os.path.isfile(self.pickle_path):
            logging.info('Saving backup of {} as {}.old'.format(
                self.pickle_path, self.pickle_path))
            os.rename(self.pickle_path, self.pickle_path+".old")
        with open(self.pickle_path, "wb") as File:
            pickle.dump(d, File)

1008
1009
    def get_saved_data_dictionary(self):
        """ Returns dictionary of the data saved as pickle

        Returns
        -------
        d: dict
            The data dictionary previously written to self.pickle_path.
        """
        # BUGFIX: pickle files are binary; opening in text mode ('r')
        # fails under python3 and can corrupt reads on some platforms.
        with open(self.pickle_path, "rb") as File:
            d = pickle.load(File)
        return d

1014
    def _check_old_data_is_okay_to_use(self):
        """ Decide whether previously pickled results can be re-used.

        Returns True when the pickled set-up matches the currently requested
        one (and is newer than the SFT data), False otherwise. Differences
        between the two set-ups are logged for the user.
        """
        if args.use_old_data:
            logging.info("Forcing use of old data")
            return True

        if not os.path.isfile(self.pickle_path):
            logging.info('No pickled data found')
            return False

        if self.sftfilepath is not None:
            # Re-run if any SFT is newer than the pickled results
            sft_mtimes = [os.path.getmtime(f)
                          for f in self._get_list_of_matching_sfts()]
            if os.path.getmtime(self.pickle_path) < min(sft_mtimes):
                logging.info('Pickled data outdates sft files')
                return False

        old_d = self.get_saved_data_dictionary().copy()
        new_d = self._get_data_dictionary_to_save().copy()

        # Only the input parameters are compared, not the sampling products
        for product in ('samples', 'sampler', 'lnprobs', 'lnlikes'):
            old_d.pop(product)

        changed = []
        for key in new_d:
            if key not in old_d:
                raise ValueError('Keys {} not in old dictionary'.format(key))
            if new_d[key] != old_d[key]:
                changed.append((key, old_d[key], new_d[key]))

        if not changed:
            return True
        logging.warning("Saved data differs from requested")
        logging.info("Differences found in following keys:")
        for entry in changed:
            if len(entry) == 3:
                # Only print old -> new for compactly printable values
                if np.isscalar(entry[1]) or entry[0] == 'nsteps':
                    logging.info("    {} : {} -> {}".format(*entry))
                else:
                    logging.info("    " + entry[0])
            else:
                logging.info(entry)
        return False

    def get_max_twoF(self, threshold=0.05):
        """ Returns the max likelihood sample and the corresponding 2F value

        Note: the sample is returned as a dictionary along with an estimate of
        the standard deviation calculated from the std of all samples with a
        twoF within `threshold` (relative) to the max twoF

        """
        # NOTE(review): `threshold` is documented above but not used in the
        # visible implementation — confirm against the full file.
        if any(np.isposinf(self.lnlikes)):
            logging.info('twoF values contain positive infinite values')
        if any(np.isneginf(self.lnlikes)):
            logging.info('twoF values contain negative infinite values')
        if any(np.isnan(self.lnlikes)):
            logging.info('twoF values contain nan')
        # BUGFIX: the argmax was previously taken over the finite-filtered
        # array (self.lnlikes[idxs]) but the resulting index was used on the
        # *unfiltered* lnlikes and samples arrays, returning the wrong sample
        # whenever any non-finite value preceded the maximum. Masking the
        # non-finite entries to -inf keeps the index valid for both arrays.
        finite = np.isfinite(self.lnlikes)
        jmax = np.argmax(np.where(finite, self.lnlikes, -np.inf))
        maxlogl = self.lnlikes[jmax]
        d = OrderedDict()

        if self.BSGL:
            # When sampling with BSGL, re-evaluate the best sample with plain
            # 2F so a comparable maxtwoF is reported.
            if hasattr(self, 'search') is False:
                self._initiate_search_object()
            p = self.samples[jmax]
            self.search.BSGL = False
            maxtwoF = self.logl(p, self.search)
            self.search.BSGL = self.BSGL
        else:
            maxtwoF = maxlogl

        repeats = []
        for i, k in enumerate(self.theta_keys):
            if k in d and k not in repeats:
                d[k+'_0'] = d[k]  # relabel the old key
                d.pop(k)
                repeats.append(k)
            if k in repeats:
                # Generate a unique suffixed key for the repeated parameter
                k = k + '_0'
                count = 1
                while k in d:
                    k = k.replace('_{}'.format(count-1), '_{}'.format(count))
                    count += 1
            d[k] = self.samples[jmax][i]
        return d, maxtwoF

    def get_median_stds(self):
        """ Returns a dict of the median and std of all production samples """
        summary = OrderedDict()
        seen_twice = []
        for column, key in zip(self.samples.T, self.theta_keys):
            if key in summary and key not in seen_twice:
                # Second column with the same name: relabel the first
                # occurrence with an '_0' suffix before adding this one
                summary[key + '_0'] = summary[key]
                summary[key + '_0_std'] = summary[key + '_std']
                summary.pop(key)
                summary.pop(key + '_std')
                seen_twice.append(key)
            if key in seen_twice:
                # Generate a unique suffixed key for the repeated parameter
                key = key + '_0'
                count = 1
                while key in summary:
                    key = key.replace('_{}'.format(count-1),
                                      '_{}'.format(count))
                    count += 1
            summary[key] = np.median(column)
            summary[key + '_std'] = np.std(column)
        return summary

1127
    def check_if_samples_are_railing(self, threshold=0.01):
        """ Returns a boolean estimate of if the samples are railing

        Parameters
        ----------
        threshold: float [0, 1]
            Fraction of the uniform prior to test (at upper and lower bound)
        """
        any_railing = False
        for column, key in zip(self.samples.T, self.theta_keys):
            prior = self.theta_prior[key]
            if prior['type'] != 'unif':
                continue  # only uniform priors have hard edges to rail on
            prior_range = prior['upper'] - prior['lower']
            edges = []
            fracs = []
            for bound in ['lower', 'upper']:
                near = np.abs(column - prior[bound])/prior_range < threshold
                if np.any(near):
                    edges.append(bound)
                    fracs.append(str(100*float(np.sum(near))/len(near)))
            if len(edges) > 0:
                logging.warning(
                    '{}% of the {} posterior is railing on the {} edges'
                    .format('% & '.join(fracs), key, ' & '.join(edges)))
                any_railing = True
        return any_railing

1154
1155
1156
1157
    def write_par(self, method='med'):
        """ Writes a .par of the best-fit params with an estimated std

        Parameters
        ----------
        method: str
            Either 'med' to write the median of the posterior samples, or
            'twoFmax' to write the parameters of the maximum-twoF sample.
        """
        logging.info('Writing {}/{}.par using the {} method'.format(
            self.outdir, self.label, method))

        median_std_d = self.get_median_stds()
        max_twoF_d, max_twoF = self.get_max_twoF()

        logging.info('Writing par file with max twoF = {}'.format(max_twoF))
        filename = '{}/{}.par'.format(self.outdir, self.label)
        with open(filename, 'w+') as f:
            f.write('MaxtwoF = {}\n'.format(max_twoF))
            f.write('tref = {}\n'.format(self.tref))
            # BUGFIX: the attribute is named 'theta0_idx'; the previous check
            # for 'theta0_index' never matched, so this line was silently
            # omitted from the output file.
            if hasattr(self, 'theta0_idx'):
                f.write('theta0_index = {}\n'.format(self.theta0_idx))
            if method == 'med':
                # .items() replaces the python2-only .iteritems()
                for key, val in median_std_d.items():
                    f.write('{} = {:1.16e}\n'.format(key, val))
            if method == 'twoFmax':
                for key, val in max_twoF_d.items():
                    f.write('{} = {:1.16e}\n'.format(key, val))

Gregory Ashton's avatar
Gregory Ashton committed
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
    def write_prior_table(self):
        """ Writes a latex table of the priors to {outdir}/{label}_prior.tex

        Raises
        ------
        ValueError
            If a prior dictionary has an unrecognised 'type'.
        """
        with open('{}/{}_prior.tex'.format(self.outdir, self.label), 'w') as f:
            f.write(r"\begin{tabular}{c l c} \hline" + '\n'
                    r"Parameter & & &  \\ \hhline{====}")

            # .items() replaces the python2-only .iteritems()
            for key, prior in self.theta_prior.items():
                if type(prior) is dict:
                    Type = prior['type']
                    if Type == "unif":
                        a = prior['lower']
                        b = prior['upper']
                        line = r"{} & $\mathrm{{Unif}}$({}, {}) & {}\\"
                    elif Type == "norm":
                        a = prior['loc']
                        b = prior['scale']
                        line = r"{} & $\mathcal{{N}}$({}, {}) & {}\\"
                    elif Type == "halfnorm":
                        a = prior['loc']
                        b = prior['scale']
                        line = r"{} & $|\mathcal{{N}}$({}, {})| & {}\\"
                    else:
                        # BUGFIX: an unknown type previously fell through and
                        # reused stale a/b/line values from the preceding
                        # iteration (or raised NameError); fail loudly instead.
                        raise ValueError(
                            'Unknown prior type {}'.format(Type))

                    u = self.unit_dictionary[key]
                    s = self.symbol_dictionary[key]
                    f.write("\n")
                    a = helper_functions.texify_float(a)
                    b = helper_functions.texify_float(b)
                    f.write(" " + line.format(s, a, b, u) + r" \\")
            f.write("\n" + r"\end{tabular}" + "\n")

1205
    def print_summary(self):
        """ Prints a summary of the max twoF found to the terminal """
        max_twoFd, max_twoF = self.get_max_twoF()
        median_std_d = self.get_median_stds()
        logging.info('Summary:')
        if hasattr(self, 'theta0_idx'):
            logging.info('theta0 index: {}'.format(self.theta0_idx))
        logging.info('Max twoF: {} with parameters:'.format(max_twoF))
        # BUGFIX: sorted() handles python3 dict views; np.sort(dict.keys())
        # fails there because the view cannot be coerced to an array of keys.
        for k in sorted(max_twoFd.keys()):
            print('  {:10s} = {:1.9e}'.format(k, max_twoFd[k]))
        logging.info('Median +/- std for production values')
        for k in sorted(median_std_d.keys()):
            if 'std' not in k:
                logging.info('  {:10s} = {:1.9e} +/- {:1.9e}'.format(
                    k, median_std_d[k], median_std_d[k+'_std']))
        logging.info('\n')
1221

1222
    def _CF_twoFmax(self, theta, twoFmax, ntrials):
1223
1224
1225
1226
        Fmax = twoFmax/2.0
        return (np.exp(1j*theta*twoFmax)*ntrials/2.0
                * Fmax*np.exp(-Fmax)*(1-(1+Fmax)*np.exp(-Fmax))**(ntrials-1))

1227
    def _pdf_twoFhat(self, twoFhat, nglitch, ntrials, twoFmax=100, dtwoF=0.1):
1228
1229
1230
1231
1232
        if np.ndim(ntrials) == 0:
            ntrials = np.zeros(nglitch+1) + ntrials
        twoFmax_int = np.arange(0, twoFmax, dtwoF)
        theta_int = np.arange(-1/dtwoF, 1./dtwoF, 1./twoFmax)
        CF_twoFmax_theta = np.array(
1233
            [[np.trapz(self._CF_twoFmax(t, twoFmax_int, ntrial), twoFmax_int)
1234
1235
1236
1237
1238
1239
1240
1241
              for t in theta_int]
             for ntrial in ntrials])
        CF_twoFhat_theta = np.prod(CF_twoFmax_theta, axis=0)
        pdf = (1/(2*np.pi)) * np.array(
            [np.trapz(np.exp(-1j*theta_int*twoFhat_val)
             * CF_twoFhat_theta, theta_int) for twoFhat_val in twoFhat])
        return pdf.real

1242
    def _p_val_twoFhat(self, twoFhat, ntrials, twoFhatmax=500, Npoints=1000):
1243
        """ Caluculate the p-value for the given twoFhat in Gaussian noise
1244
1245
1246
1247
1248
1249
1250
1251
1252

        Parameters
        ----------
        twoFhat: float
            The observed twoFhat value
        ntrials: int, array of len Nglitch+1
            The number of trials for each glitch+1
        """
        twoFhats = np.linspace(twoFhat, twoFhatmax, Npoints)
1253
        pdf = self._pdf_twoFhat(twoFhats, self.nglitch, ntrials)
1254
1255
1256
1257
1258
1259
1260
1261
        return np.trapz(pdf, twoFhats)

    def get_p_value(self, delta_F0, time_trials=0):
        """ Get's the p-value for the maximum twoFhat value

        The trial count of each inter-glitch segment is estimated as
        time_trials + delta_F0 * (segment duration).
        """
        d, max_twoF = self.get_max_twoF()
        if self.nglitch == 1:
            tglitches = [d['tglitch']]
        else:
            tglitches = []
            for i in range(self.nglitch):
                tglitches.append(d['tglitch_{}'.format(i)])
        tboundaries = [self.minStartTime] + tglitches + [self.maxStartTime]
        ntrials = []
        for dT in np.diff(tboundaries):
            ntrials.append(time_trials + delta_F0 * dT)
        p_val = self._p_val_twoFhat(max_twoF, ntrials)
        print('p-value = {}'.format(p_val))
        return p_val

1271
    def get_evidence(self):
        """ Get the log10 evidence and error estimate """
        # The burn-in fraction of the production stage is nburn/(nburn+nprod)
        fburnin = float(self.nsteps[-2])/np.sum(self.nsteps[-2:])
        lnev, lnev_err = self.sampler.thermodynamic_integration_log_evidence(
            fburnin=fburnin)

        # Convert natural log to log10
        ln10 = np.log(10)
        return lnev/ln10, lnev_err/ln10

1281
    def _compute_evidence_long(self):
        """ Computes the evidence/marginal likelihood for the model

        Performs thermodynamic integration of the mean log-likelihood over
        the inverse temperatures (betas) and saves a diagnostic plot to
        {outdir}/{label}_beta_lnl.png.
        """
        betas = self.betas
        # Use only production-stage samples (after the final burn-in)
        alllnlikes = self.sampler.lnlikelihood[:, :, self.nsteps[-2]:]
        mean_lnlikes = np.mean(np.mean(alllnlikes, axis=1), axis=1)

        # Integrate from the smallest beta upwards
        mean_lnlikes = mean_lnlikes[::-1]
        betas = betas[::-1]

        fig, (ax1, ax2) = plt.subplots(nrows=2, figsize=(6, 8))

        if any(np.isinf(mean_lnlikes)):
            print("WARNING mean_lnlikes contains inf: recalculating without"
                  " the {} infs".format(len(betas[np.isinf(mean_lnlikes)])))
            idxs = np.isinf(mean_lnlikes)
            mean_lnlikes = mean_lnlikes[~idxs]
            betas = betas[~idxs]

        # BUGFIX: the evidence was previously computed only inside the
        # inf-handling branch above, leaving log10evidence(_err) undefined
        # (NameError on the print below) whenever all values were finite.
        log10evidence = np.trapz(mean_lnlikes, betas)/np.log(10)
        z1 = np.trapz(mean_lnlikes, betas)
        # Error estimate: compare against integration on half the points
        z2 = np.trapz(mean_lnlikes[::-1][::2][::-1],
                      betas[::-1][::2][::-1])
        log10evidence_err = np.abs(z1 - z2) / np.log(10)

        ax1.semilogx(betas, mean_lnlikes, "-o")
        ax1.set_xlabel(r"$\beta$")
        ax1.set_ylabel(r"$\langle \log(\mathcal{L}) \rangle$")
        print("log10 evidence for {} = {} +/- {}".format(
              self.label, log10evidence, log10evidence_err))
        min_betas = []
        evidence = []
        # BUGFIX: integer division; len(betas)/2 is a float under python3
        for i in range(len(betas)//2):
            min_betas.append(betas[i])
            lnZ = np.trapz(mean_lnlikes[i:], betas[i:])
            evidence.append(lnZ/np.log(10))

        ax2.semilogx(min_betas, evidence, "-o")
        ax2.set_ylabel(r"$\int_{\beta_{\textrm{Min}}}^{\beta=1}" +
                       r"\langle \log(\mathcal{L})\rangle d\beta$", size=16)
        ax2.set_xlabel(r"$\beta_{\textrm{min}}$")
        plt.tight_layout()
        fig.savefig("{}/{}_beta_lnl.png".format(self.outdir, self.label))

1323

Gregory Ashton's avatar
Gregory Ashton committed
1324
1325
class MCMCGlitchSearch(MCMCSearch):
    """ MCMC search using the SemiCoherentGlitchSearch """

    # LaTeX symbols used when labelling posterior plots
    symbol_dictionary = {
        'F0': '$f$',
        'F1': '$\dot{f}$',
        'F2': '$\ddot{f}$',
        'alpha': r'$\alpha$',
        'delta': '$\delta$',
        'delta_F0': '$\delta f$',
        'delta_F1': '$\delta \dot{f}$',
        'tglitch': '$t_\mathrm{glitch}$',
    }
    # Physical units of each search parameter
    unit_dictionary = {
        'F0': 'Hz',
        'F1': 'Hz/s',
        'F2': 'Hz/s$^2$',
        'alpha': r'rad',
        'delta': 'rad',
        'delta_F0': 'Hz',
        'delta_F1': 'Hz/s',
        'tglitch': 's',
    }
    # Rescaling applied to the glitch time when plotting
    rescale_dictionary = {
        'tglitch': {
            'multiplier': 1/86400.,
            'subtractor': 'minStartTime',
            'unit': 'day',
            'label': 'Glitch time \n days after minStartTime',
        },
    }
1341

Gregory Ashton's avatar
Gregory Ashton committed
1342
    @helper_functions.initializer
    def __init__(self, label, outdir, sftfilepath, theta_prior, tref,
                 minStartTime, maxStartTime, nglitch=1, nsteps=[100, 100],
                 nwalkers=100, ntemps=1, log10temperature_min=-5,
                 theta_initial=None, scatter_val=1e-10, dtglitchmin=1*86400,
                 theta0_idx=0, detectors=None, BSGL=False, minCoverFreq=None,
                 maxCoverFreq=None, earth_ephem=None, sun_ephem=None):
        """
        Parameters
        ----------
        label, outdir: str
            A label and directory to read/write data from/to
        sftfilepath: str
            File pattern to match SFTs
        theta_prior: dict
            Dictionary of priors and fixed values for the search parameters.
            For each parameters (key of the dict), if it is to be held fixed
            the value should be the constant float, if it is be searched, the
            value should be a dictionary of the prior.
        theta_initial: dict, array, (None)
            Either a dictionary of distribution about which to distribute the
            initial walkers about, an array (from which the walkers will be
            scattered by scatter_val), or None in which case the prior is used.
        scatter_val, float or ndim array
            Size of scatter to use about the initialisation step, if given as
            an array it must be of length ndim and the order is given by
            theta_keys
        nglitch: int
            The number of glitches to allow
        tref, minStartTime, maxStartTime: int
            GPS seconds of the reference time, start time and end time
        nsteps: list (m,)
            List specifying the number of steps to take, the last two entries
            give the nburn and nprod of the 'production' run, all entries
            before are for iterative initialisation steps (usually just one)
            e.g. [1000, 1000, 500].
        dtglitchmin: int
            The minimum duration (in seconds) of a segment between two glitches
            or a glitch and the start/end of the data
        nwalkers, ntemps: int
            The number of walkers and temperatures to use in the parallel
            tempered PTSampler.
        log10temperature_min: float < 0
            The log_10(tmin) value, the set of betas passed to PTSampler are
            generated from np.logspace(0, log10temperature_min, ntemps).
        theta0_idx: int
            Index (zero-based) of which segment the theta refers to - useful
            if providing a tight prior on theta to allow the signal to jump
            to theta (and not just from)
        detectors: str
            Two character reference to the data to use, specify None for no
            constraint.
        minCoverFreq, maxCoverFreq: float
            Minimum and maximum instantaneous frequency which will be covered
            over the SFT time span as passed to CreateFstatInput
        earth_ephem, sun_ephem: str
            Paths of the two files containing positions of Earth and Sun,
            respectively at evenly spaced times, as passed to CreateFstatInput
            If None defaults defined in BaseSearchClass will be used

        """

        if not os.path.isdir(outdir):
            os.mkdir(outdir)
        self._add_log_file()
        logging.info(('Set-up MCMC glitch search with {} glitches for model {}'
                      ' on data {}').format(self.nglitch, self.label,
                                            self.sftfilepath))
        self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
        self._unpack_input_theta()
        self.ndim = len(self.theta_keys)
        # Temperature ladder for the parallel-tempered sampler
        self.betas = (np.logspace(0, self.log10temperature_min, self.ntemps)
                      if self.log10temperature_min else None)
        if earth_ephem is None:
            self.earth_ephem = self.earth_ephem_default
        if sun_ephem is None:
            self.sun_ephem = self.sun_ephem_default

        # Command-line --clean flag forces a fresh run, keeping a backup
        if args.clean and os.path.isfile(self.pickle_path):
            os.rename(self.pickle_path, self.pickle_path+".old")

        self.old_data_is_okay_to_use = self._check_old_data_is_okay_to_use()
        self._log_input()
Gregory Ashton's avatar
Gregory Ashton committed
1427

1428
    def _initiate_search_object(self):
        """ Instantiate the SemiCoherentGlitchSearch used to evaluate 2F """
        logging.info('Setting up search object')
        search_kwargs = dict(
            label=self.label, outdir=self.outdir,
            sftfilepath=self.sftfilepath, tref=self.tref,
            minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
            minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq,
            earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem,
            detectors=self.detectors, BSGL=self.BSGL,
            nglitch=self.nglitch, theta0_idx=self.theta0_idx)
        self.search = core.SemiCoherentGlitchSearch(**search_kwargs)
Gregory Ashton's avatar
Gregory Ashton committed
1437
1438
1439

    def logp(self, theta_vals, theta_prior, theta_keys, search):
        """ Returns the log prior of the point theta_vals; -inf when glitch
        times are unsorted or closer together than dtglitchmin """
        if self.nglitch > 1:
            boundaries = ([self.minStartTime]
                          + list(theta_vals[-self.nglitch:])
                          + [self.maxStartTime])
            if not np.array_equal(boundaries, np.sort(boundaries)):
                return -np.inf
            if np.any(np.diff(boundaries) < self.dtglitchmin):
                return -np.inf

        lnpriors = [self._generic_lnprior(**theta_prior[key])(val)
                    for val, key in zip(theta_vals, theta_keys)]
        return np.sum(lnpriors)

    def logl(self, theta, search):
        """ Returns the log likelihood (glitch-robust 2F) of theta; -inf if
        the glitch times are not time-ordered """
        if self.nglitch > 1:
            boundaries = ([self.minStartTime] + list(theta[-self.nglitch:])
                          + [self.maxStartTime])
            if not np.array_equal(boundaries, np.sort(boundaries)):
                return -np.inf

        # Splice the sampled values into the full (fixed+searched) vector
        for slot, value in zip(self.theta_idxs, theta):
            self.fixed_theta[slot] = value
        return search.compute_nglitch_fstat(*self.fixed_theta)

1463
    def _unpack_input_theta(self):
Gregory Ashton's avatar
Gregory Ashton committed
1464
1465
1466
        glitch_keys = ['delta_F0', 'delta_F1', 'tglitch']
        full_glitch_keys = list(np.array(
            [[gk]*self.nglitch for gk in glitch_keys]).flatten())
1467
1468
1469
1470

        if 'tglitch_0' in self.theta_prior:
            full_glitch_keys[-self.nglitch:] = [
                'tglitch_{}'.format(i) for i in range(self.nglitch)]
1471
1472
1473
1474
            full_glitch_keys[-2*self.nglitch:-1*self.nglitch] = [
                'delta_F1_{}'.format(i) for i in range(self.nglitch)]
            full_glitch_keys[-4*self.nglitch:-2*self.nglitch] = [
                'delta_F0_{}'.format(i) for i in range(self.nglitch)]
Gregory Ashton's avatar
Gregory Ashton committed
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
        full_theta_keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta']+full_glitch_keys
        full_theta_keys_copy = copy.copy(full_theta_keys)

        glitch_symbols = ['$\delta f$', '$\delta \dot{f}$', r'$t_{glitch}$']
        full_glitch_symbols = list(np.array(
            [[gs]*self.nglitch for gs in glitch_symbols]).flatten())
        full_theta_symbols = (['$f$', '$\dot{f}$', '$\ddot{f}$', r'$\alpha$',
                               r'$\delta$'] + full_glitch_symbols)
        self.theta_keys = []
        fixed_theta_dict = {}
        for key, val in self.theta_prior.iteritems():
            if type(val) is dict:
                fixed_theta_dict[key] = 0
                if key in glitch_keys:
                    for i in range(self.nglitch):
                        self.theta_keys.append(key)
                else:
                    self.theta_keys.append(key)
            elif type(val) in [float, int, np.float64]:
                fixed_theta_dict[key] = val
            else:
                raise ValueError(
                    'Type {} of {} in theta not recognised'.format(
                        type(val), key))
            if key in glitch_keys:
                for i in range(self.nglitch):
                    full_theta_keys_copy.pop(full_theta_keys_copy.index(key))
            else:
                full_theta_keys_copy.pop(full_theta_keys_copy.index(key))

        if len(full_theta_keys_copy) > 0:
            raise ValueError(('Input dictionary `theta` is missing the'
                              'following keys: {}').format(
                                  full_theta_keys_copy))

        self.fixed_theta = [fixed_theta_dict[key] for key in full_theta_keys]
        self.theta_idxs = [full_theta_keys.index(k) for k in self.theta_keys]
        self.theta_symbols = [full_theta_symbols[i] for i in self.theta_idxs]

        idxs = np.argsort(self.theta_idxs)
        self.theta_idxs = [self.theta_idxs[i] for i in idxs]
        self.theta_symbols = [self.theta_symbols[i] for i in idxs]
        self.theta_keys = [self.theta_keys[i] for i in idxs]

        # Correct for number of glitches in the idxs
        self.theta_idxs = np.array(self.theta_idxs)
        while np.sum(self.theta_idxs[:-1] == self.theta_idxs[1:]) > 0:
            for i, idx in enumerate(self.theta_idxs):
                if idx in self.theta_idxs[:i]:
                    self.theta_idxs[i] += 1

1526
    def _get_data_dictionary_to_save(self):
1527
1528
1529
1530
1531
1532
1533
        d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers,
                 ntemps=self.ntemps, theta_keys=self.theta_keys,
                 theta_prior=self.theta_prior, scatter_val=self.scatter_val,
                 log10temperature_min=self.log10temperature_min,
                 theta0_idx=self.theta0_idx, BSGL=self.BSGL)
        return d

1534
    def _apply_corrections_to_p0(self, p0):
Gregory Ashton's avatar
Gregory Ashton committed
1535
1536
1537
1538
1539
1540
        p0 = np.array(p0)
        if self.nglitch > 1:
            p0[:, :, -self.nglitch:] = np.sort(p0[:, :, -self.nglitch:],
                                               axis=2)
        return p0

Gregory Ashton's avatar
Gregory Ashton committed
1541
1542
1543
1544
1545
1546
1547
1548
    def plot_cumulative_max(self):
        """ Plot the cumulative twoF of each inter-glitch segment

        Saves the figure to {outdir}/{label}_twoFcumulative.png. Segments
        shorter than 5 days are skipped.
        """
        fig, ax = plt.subplots()
        d, maxtwoF = self.get_max_twoF()
        # BUGFIX: .items() replaces the python2-only .iteritems()
        for key, val in self.theta_prior.items():
            if key not in d:
                d[key] = val

        if self.nglitch > 1:
            delta_F0s = [d['delta_F0_{}'.format(i)] for i in
                         range(self.nglitch)]
            delta_F0s.insert(self.theta0_idx, 0)
            delta_F0s = np.array(delta_F0s)
            # Jumps before the reference segment act with opposite sign
            delta_F0s[:self.theta0_idx] *= -1
            tglitches = [d['tglitch_{}'.format(i)] for i in
                         range(self.nglitch)]
        elif self.nglitch == 1:
            delta_F0s = [d['delta_F0']]
            delta_F0s.insert(self.theta0_idx, 0)
            delta_F0s = np.array(delta_F0s)
            delta_F0s[:self.theta0_idx] *= -1
            tglitches = [d['tglitch']]

        tboundaries = [self.minStartTime] + tglitches + [self.maxStartTime]

        for j in range(self.nglitch+1):
            ts = tboundaries[j]
            te = tboundaries[j+1]
            if (te - ts)/86400 < 5:
                logging.info('Period too short to perform cumulative search')
                continue
            if j < self.theta0_idx:
                summed_deltaF0 = np.sum(delta_F0s[j:self.theta0_idx])
                F0_j = d['F0'] - summed_deltaF0
                taus, twoFs = self.search.calculate_twoF_cumulative(
                    F0_j, F1=d['F1'], F2=d['F2'], Alpha=d['Alpha'],
                    Delta=d['Delta'], tstart=ts, tend=te)

            elif j >= self.theta0_idx:
                summed_deltaF0 = np.sum(delta_F0s[self.theta0_idx:j+1])
                F0_j = d['F0'] + summed_deltaF0
                taus, twoFs = self.search.calculate_twoF_cumulative(
                    F0_j, F1=d['F1'], F2=d['F2'], Alpha=d['Alpha'],
                    Delta=d['Delta'], tstart=ts, tend=te)
            ax.plot(ts+taus, twoFs)

        ax.set_xlabel('GPS time')
        fig.savefig('{}/{}_twoFcumulative.png'.format(self.outdir, self.label))

Gregory Ashton's avatar
Gregory Ashton committed
1590

1591
1592
class MCMCSemiCoherentSearch(MCMCSearch):
    """ MCMC search for a signal using the semi-coherent ComputeFstat """

    @helper_functions.initializer
    def __init__(self, label, outdir, theta_prior, tref, sftfilepath=None,
                 nsegs=None, nsteps=None, nwalkers=100,
                 binary=False, ntemps=1, log10temperature_min=-5,
                 theta_initial=None, scatter_val=1e-10, detectors=None,
                 BSGL=False, minStartTime=None, maxStartTime=None,
                 minCoverFreq=None, maxCoverFreq=None, earth_ephem=None,
                 sun_ephem=None, injectSources=None, assumeSqrtSX=None):
        """ Set up the semi-coherent MCMC search

        Parameters
        ----------
        label, outdir: str
            Search label and output directory (created if it doesn't exist)
        theta_prior: dict
            Priors and fixed values for the search parameters
        tref: int
            GPS reference time
        sftfilepath: str
            Pattern matching the SFT data files
        nsegs: int
            Number of segments for the semi-coherent computation
        nsteps: list of int
            MCMC step counts; defaults to [100, 100, 100]

        All arguments are also stored as attributes of the same name by the
        `helper_functions.initializer` decorator.
        """
        if self.nsteps is None:
            # Apply the historical default per-instance rather than using a
            # shared mutable default argument in the signature.
            self.nsteps = [100, 100, 100]
        if os.path.isdir(outdir) is False:
            os.mkdir(outdir)
        self._add_log_file()
        logging.info(('Set-up MCMC semi-coherent search for model {} on data'
                      '{}').format(
            self.label, self.sftfilepath))
        self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
        self._unpack_input_theta()
        self.ndim = len(self.theta_keys)
        # Geometric inverse-temperature ladder for parallel tempering;
        # log10temperature_min == 0 (falsy) disables the ladder.
        if self.log10temperature_min:
            self.betas = np.logspace(0, self.log10temperature_min, self.ntemps)
        else:
            self.betas = None
        if earth_ephem is None:
            self.earth_ephem = self.earth_ephem_default
        if sun_ephem is None:
            self.sun_ephem = self.sun_ephem_default

        # --clean requests a fresh run: retire any previously saved data
        if args.clean and os.path.isfile(self.pickle_path):
            os.rename(self.pickle_path, self.pickle_path+".old")

        self._log_input()

    def _get_data_dictionary_to_save(self):
        """ Returns the input dictionary used to check saved-data validity """
        d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers,
                 ntemps=self.ntemps, theta_keys=self.theta_keys,
                 theta_prior=self.theta_prior, scatter_val=self.scatter_val,
                 log10temperature_min=self.log10temperature_min,
                 BSGL=self.BSGL, nsegs=self.nsegs)
        return d

    def _initiate_search_object(self):
        """ Create the core.SemiCoherentSearch instance used by logl """
        logging.info('Setting up search object')
        self.search = core.SemiCoherentSearch(
            label=self.label, outdir=self.outdir, tref=self.tref,
            nsegs=self.nsegs, sftfilepath=self.sftfilepath, binary=self.binary,
            BSGL=self.BSGL, minStartTime=self.minStartTime,
            maxStartTime=self.maxStartTime, minCoverFreq=self.minCoverFreq,
            maxCoverFreq=self.maxCoverFreq, detectors=self.detectors,
            earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem,
            injectSources=self.injectSources, assumeSqrtSX=self.assumeSqrtSX)

    def logp(self, theta_vals, theta_prior, theta_keys, search):
        """ Log prior: sum of the individual parameter log-priors """
        H = [self._generic_lnprior(**theta_prior[key])(p) for p, key in
             zip(theta_vals, theta_keys)]
        return np.sum(H)

    def logl(self, theta, search):
        """ Log likelihood: the semi-coherent detection statistic at theta """
        # Splice the sampled values into the template of fixed parameters
        # before evaluating the statistic.
        for j, theta_i in enumerate(self.theta_idxs):
            self.fixed_theta[theta_i] = theta[j]
        FS = search.run_semi_coherent_computefstatistic_single_point(
            *self.fixed_theta)
        return FS


Gregory Ashton's avatar
Gregory Ashton committed
1660
1661
class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
    """ A follow up procedure increasing the coherence time in a zoom """

    def _get_data_dictionary_to_save(self):
        """ Returns the input dictionary used to check saved-data validity """
        d = dict(nwalkers=self.nwalkers, ntemps=self.ntemps,
                 theta_keys=self.theta_keys, theta_prior=self.theta_prior,
                 scatter_val=self.scatter_val,
                 log10temperature_min=self.log10temperature_min,
                 BSGL=self.BSGL, run_setup=self.run_setup)
        return d

    def update_search_object(self):
        """ Re-initialise the search after changing its segment count """
        logging.info('Update search object')
        self.search.init_computefstatistic_single_point()

    def get_width_from_prior(self, prior, key):
        """ Width of a uniform prior; None for other prior types """
        if prior[key]['type'] == 'unif':
            return prior[key]['upper'] - prior[key]['lower']

    def get_mid_from_prior(self, prior, key):
        """ Midpoint of a uniform prior; None for other prior types """
        if prior[key]['type'] == 'unif':
            return .5*(prior[key]['upper'] + prior[key]['lower'])

    def init_V_estimate_parameters(self):
        """ Extract the quantities needed to estimate template counts

        Returns
        -------
        fiducial_freq: float
            Representative frequency (prior midpoint, or the fixed value)
        DeltaOmega: float
            Solid-angle extent of the sky search (0 if no sky search)
        DeltaFs: list of float
            Widths of the frequency and spin-down priors
        """
        if 'Alpha' in self.theta_keys:
            DeltaAlpha = self.get_width_from_prior(self.theta_prior, 'Alpha')
            DeltaDelta = self.get_width_from_prior(self.theta_prior, 'Delta')
            DeltaMid = self.get_mid_from_prior(self.theta_prior, 'Delta')
            DeltaOmega = np.sin(np.pi/2 - DeltaMid) * DeltaDelta * DeltaAlpha
            logging.info('Search over Alpha and Delta')
        else:
            logging.info('No sky search requested')
            DeltaOmega = 0
        if 'F0' in self.theta_keys:
            DeltaF0 = self.get_width_from_prior(self.theta_prior, 'F0')
        else:
            raise ValueError("You aren't searching over F0?")
        DeltaFs = [DeltaF0]
        if 'F1' in self.theta_keys:
            DeltaF1 = self.get_width_from_prior(self.theta_prior, 'F1')
            DeltaFs.append(DeltaF1)
            # F2 is only meaningful if F1 is also searched over
            if 'F2' in self.theta_keys:
                DeltaF2 = self.get_width_from_prior(self.theta_prior, 'F2')
                DeltaFs.append(DeltaF2)
        logging.info('Searching over Frequency and {} spin-down components'
                     .format(len(DeltaFs)-1))

        if isinstance(self.theta_prior['F0'], dict):
            fiducial_freq = self.get_mid_from_prior(self.theta_prior, 'F0')
        else:
            fiducial_freq = self.theta_prior['F0']

        return fiducial_freq, DeltaOmega, DeltaFs

    def read_setup_input_file(self, run_setup_input_file):
        """ Load a previously pickled run-setup dictionary """
        # Pickle files must be read in binary mode
        with open(run_setup_input_file, 'rb') as f:
            d = pickle.load(f)
        return d

    def write_setup_input_file(self, run_setup_input_file, R, Nsegs0,
                               nsegs_vals, V_vals, DeltaOmega, DeltaFs):
        """ Pickle the run-setup dictionary for reuse by later runs """
        d = dict(R=R, Nsegs0=Nsegs0, nsegs_vals=nsegs_vals, V_vals=V_vals,
                 DeltaOmega=DeltaOmega, DeltaFs=DeltaFs)
        # Pickle files must be written in binary mode
        with open(run_setup_input_file, 'wb') as f:
            pickle.dump(d, f)

    def check_old_run_setup(self, old_setup, **kwargs):
        """ True if every kwarg matches the saved setup dictionary """
        try:
            truths = [val == old_setup[key] for key, val in kwargs.items()]
            return all(truths)
        except KeyError:
            return False

    def init_run_setup(self, run_setup=None, R=10, Nsegs0=None, log_table=True,
                       gen_tex_table=True):
        """ Initialise the zoom stages, estimating them if not given

        Parameters
        ----------
        run_setup: list of tuples, optional
            Each entry is ((nburn, nprod), nsegs, reset_p0); if None, an
            optimal setup is computed (or restored from a saved file)
        R, Nsegs0:
            Parameters of the optimal-setup estimation (used when
            run_setup is None)
        log_table, gen_tex_table: bool
            Whether to log a summary table / write a LaTeX table

        Returns
        -------
        run_setup: list
            The normalised setup, each entry [(nburn, nprod), nsegs, reset_p0]
        """
        if run_setup is None and Nsegs0 is None:
            raise ValueError(
                'You must either specify the run_setup, or Nsegs0 from which '
                'the optimal run_setup given R can be estimated')
        fiducial_freq, DeltaOmega, DeltaFs = self.init_V_estimate_parameters()
        if run_setup is None:
            logging.info('No run_setup provided')

            run_setup_input_file = '{}/{}_run_setup.p'.format(
                self.outdir, self.label)

            # Reuse a previously computed setup if it matches the request
            if os.path.isfile(run_setup_input_file):
                logging.info('Checking old setup input file {}'.format(
                    run_setup_input_file))
                old_setup = self.read_setup_input_file(run_setup_input_file)
                if self.check_old_run_setup(old_setup, R=R,
                                            Nsegs0=Nsegs0,
                                            DeltaOmega=DeltaOmega,
                                            DeltaFs=DeltaFs):
                    logging.info('Using old setup with R={}, Nsegs0={}'.format(
                        R, Nsegs0))
                    nsegs_vals = old_setup['nsegs_vals']
                    V_vals = old_setup['V_vals']
                    generate_setup = False
                else:
                    logging.info(
                        'Old setup does not match requested R, Nsegs0')
                    generate_setup = True
            else:
                generate_setup = True

            if generate_setup:
                nsegs_vals, V_vals = get_optimal_setup(
                    R, Nsegs0, self.tref, self.minStartTime,
                    self.maxStartTime, DeltaOmega, DeltaFs, fiducial_freq,
                    self.search.detector_names, self.earth_ephem,
                    self.sun_ephem)
                self.write_setup_input_file(run_setup_input_file, R, Nsegs0,
                                            nsegs_vals, V_vals, DeltaOmega,
                                            DeltaFs)

            # All stages burn-in only, except the final stage which also
            # produces samples
            run_setup = [((self.nsteps[0], 0),  nsegs, False)
                         for nsegs in nsegs_vals[:-1]]
            run_setup.append(
                ((self.nsteps[0], self.nsteps[1]), nsegs_vals[-1], False))

        else:
            logging.info('Calculating the number of templates for this setup')
            V_vals = []
            for i, rs in enumerate(run_setup):
                # Normalise each stage to [(nburn, nprod), nsegs, reset_p0]
                rs = list(rs)
                if len(rs) == 2:
                    rs.append(False)
                if np.shape(rs[0]) == ():
                    rs[0] = (rs[0], 0)
                run_setup[i] = rs

                if args.no_template_counting:
                    V_vals.append([1, 1, 1])
                else:
                    V, Vsky, Vpe = get_V_estimate(
                        rs[1], self.tref, self.minStartTime, self.maxStartTime,
                        DeltaOmega, DeltaFs, fiducial_freq,
                        self.search.detector_names, self.earth_ephem,
                        self.sun_ephem)
                    V_vals.append([V, Vsky, Vpe])

        if log_table:
            logging.info('Using run-setup as follows:')
            logging.info('Stage | nburn | nprod | nsegs | Tcoh d | resetp0 |'
                         ' V = Vsky x Vpe')
            for i, rs in enumerate(run_setup):
                Tcoh = (self.maxStartTime - self.minStartTime) / rs[1] / 86400
                if V_vals[i] is None:
                    vtext = 'N/A'
                else:
                    vtext = '{:1.0e} = {:1.0e} x {:1.0e}'.format(
                            V_vals[i][0], V_vals[i][1], V_vals[i][2])
                logging.info('{} | {} | {} | {} | {} | {} | {}'.format(
                    str(i).ljust(5), str(rs[0][0]).ljust(5),
                    str(rs[0][1]).ljust(5), str(rs[1]).ljust(5),
                    '{:6.1f}'.format(Tcoh), str(rs[2]).ljust(7),
                    vtext))

        if gen_tex_table:
            filename = '{}/{}_run_setup.tex'.format(self.outdir, self.label)
            if DeltaOmega > 0:
                # Sky search: include the V and Vsky columns
                with open(filename, 'w+') as f:
                    f.write(r'\begin{tabular}{c|cccccc}' + '\n')
                    f.write(r'Stage & $\Nseg$ & $\Tcoh^{\rm days}$ &'
                            r'$\Nsteps$ & $\V$ & $\Vsky$ & $\Vpe$ \\ \hline'
                            '\n')
                    for i, rs in enumerate(run_setup):
                        Tcoh = float(
                            self.maxStartTime - self.minStartTime)/rs[1]/86400
                        line = r'{} & {} & {} & {} & {} & {} & {} \\' + '\n'
                        # NOTE(review): this checks V_vals[i][0] while the
                        # other branches check V_vals[i]; confirm which
                        # sentinel get_V_estimate actually produces.
                        if V_vals[i][0] is None:
                            V = Vsky = Vpe = 'N/A'
                        else:
                            V, Vsky, Vpe = V_vals[i]
                        if rs[0][-1] == 0:
                            nsteps = rs[0][0]
                        else:
                            nsteps = '{},{}'.format(*rs[0])
                        line = line.format(i, rs[1], '{:1.1f}'.format(Tcoh),
                                           nsteps,
                                           helper_functions.texify_float(V),
                                           helper_functions.texify_float(Vsky),
                                           helper_functions.texify_float(Vpe))
                        f.write(line)
                    f.write(r'\end{tabular}' + '\n')
            else:
                # No sky search: only the per-frequency Vpe column
                with open(filename, 'w+') as f:
                    f.write(r'\begin{tabular}{c|cccc}' + '\n')
                    f.write(r'Stage & $\Nseg$ & $\Tcoh^{\rm days}$ &'
                            r'$\Nsteps$ & $\Vpe$ \\ \hline'
                            '\n')
                    for i, rs in enumerate(run_setup):
                        Tcoh = float(
                            self.maxStartTime - self.minStartTime)/rs[1]/86400
                        line = r'{} & {} & {} & {} & {} \\' + '\n'
                        if V_vals[i] is None:
                            V = Vsky = Vpe = 'N/A'
                        else:
                            V, Vsky, Vpe = V_vals[i]
                        if rs[0][-1] == 0:
                            nsteps = rs[0][0]
                        else:
                            nsteps = '{},{}'.format(*rs[0])
                        line = line.format(i, rs[1], '{:1.1f}'.format(Tcoh),
                                           nsteps,
                                           helper_functions.texify_float(Vpe))
                        f.write(line)
                    f.write(r'\end{tabular}' + '\n')

        if args.setup_only:
            logging.info("Exit as requested by setup_only flag")
            sys.exit()
        else:
            return run_setup

    def run(self, run_setup=None, proposal_scale_factor=2, R=10, Nsegs0=None,
            create_plots=True, log_table=True, gen_tex_table=True, fig=None,
            axes=None, return_fig=False, **kwargs):
        """ Run the follow-up with the given run_setup

        Parameters
        ----------
        run_setup: list of tuples
            Zoom stages ((nburn, nprod), nsegs, reset_p0); estimated from
            R and Nsegs0 if None
        proposal_scale_factor: float
            The emcee proposal scale parameter `a`
        create_plots, return_fig: bool
            Whether to produce walker plots and, if so, whether to return
            (fig, axes) instead of saving to file

        Extra kwargs are passed on to `_plot_walkers`.
        """

        self.nsegs = 1
        self._initiate_search_object()
        run_setup = self.init_run_setup(
            run_setup, R=R, Nsegs0=Nsegs0, log_table=log_table,
            gen_tex_table=gen_tex_table)
        self.run_setup = run_setup

        self.old_data_is_okay_to_use = self._check_old_data_is_okay_to_use()
        if self.old_data_is_okay_to_use is True:
            logging.warning('Using saved data from {}'.format(
                self.pickle_path))
            d = self.get_saved_data_dictionary()
            self.sampler = d['sampler']
            self.samples = d['samples']
            self.lnprobs = d['lnprobs']
            self.lnlikes = d['lnlikes']
            self.nsegs = run_setup[-1][1]
            return

        nsteps_total = 0
        for j, ((nburn, nprod), nseg, reset_p0) in enumerate(run_setup):
            # Choose the walkers' starting positions for this stage
            if j == 0:
                p0 = self._generate_initial_p0()
                p0 = self._apply_corrections_to_p0(p0)
            elif reset_p0:
                p0 = self._get_new_p0(sampler)
                p0 = self._apply_corrections_to_p0(p0)
                # self._check_initial_points(p0)
            else:
                # Continue from where the previous stage finished
                p0 = sampler.chain[:, :, -1, :]

            self.nsegs = nseg
            self.search.nsegs = nseg
            self.update_search_object()
            self.search.init_semicoherent_parameters()
            # A fresh sampler per stage: the likelihood changes with nsegs
            sampler = emcee.PTSampler(
                self.ntemps, self.nwalkers, self.ndim, self.logl, self.logp,
                logpargs=(self.theta_prior, self.theta_keys, self.search),
                loglargs=(self.search,), betas=self.betas,
                a=proposal_scale_factor)

            Tcoh = (self.maxStartTime-self.minStartTime)/nseg/86400.
            logging.info(('Running {}/{} with {} steps and {} nsegs '
                          '(Tcoh={:1.2f} days)').format(
                j+1, len(run_setup), (nburn, nprod), nseg, Tcoh))
            sampler = self._run_sampler(sampler, p0, nburn=nburn, nprod=nprod)
            logging.info("Mean acceptance fraction: {}"
                         .format(np.mean(sampler.acceptance_fraction, axis=1)))
            if self.ntemps > 1:
                logging.info("Tswap acceptance fraction: {}"
                             .format(sampler.tswap_acceptance_fraction))
            logging.info('Max detection statistic of run was {}'.format(
                np.max(sampler.lnlikelihood)))

            if create_plots:
                fig, axes = self._plot_walkers(
                    sampler, symbols=self.theta_symbols, fig=fig, axes=axes,
                    nprod=nprod, xoffset=nsteps_total, **kwargs)
                # Mark the stage boundary on each parameter trace
                for ax in axes[:self.ndim]:
                    ax.axvline(nsteps_total, color='k', ls='--', lw=0.25)

            nsteps_total += nburn+nprod

        # Keep only the zero-temperature chain of the final stage,
        # discarding its burn-in steps
        samples = sampler.chain[0, :, nburn:, :].reshape((-1, self.ndim))
        lnprobs = sampler.lnprobability[0, :, nburn:].reshape((-1))
        lnlikes = sampler.lnlikelihood[0, :, nburn:].reshape((-1))
        self.sampler = sampler
        self.samples = samples
        self.lnprobs = lnprobs
        self.lnlikes = lnlikes
        self._save_data(sampler, samples, lnprobs, lnlikes)

        if create_plots:
            try:
                fig.tight_layout()
            except (ValueError, RuntimeError) as e:
                logging.warning('Tight layout encountered {}'.format(e))
            if return_fig:
                return fig, axes
            else:
                fig.savefig('{}/{}_walkers.png'.format(
                    self.outdir, self.label), dpi=200)

Gregory Ashton's avatar
Gregory Ashton committed
1970

Gregory Ashton's avatar
Gregory Ashton committed
1971
1972
1973
class MCMCTransientSearch(MCMCSearch):
    """ MCMC search for a transient signal using the ComputeFstat """

Gregory Ashton's avatar
Gregory Ashton committed
1974
1975
    symbol_dictionary = dict(
        F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$',
1976
1977
        alpha=r'$\alpha$', delta='$\delta$',
        transient_tstart='$t_\mathrm{start}$', transient_duration='$\Delta T$')
Gregory Ashton's avatar
Gregory Ashton committed
1978
1979
    unit_dictionary = dict(
        F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad',
1980
        transient_tstart='s', transient_duration='s')
Gregory Ashton's avatar
Gregory Ashton committed
1981

1982
1983
    rescale_dictionary = dict(
        transient_duration={'multiplier': 1/86400.,
1984
1985
                            'unit': 'day',
                            'symbol': 'Transient duration'},
1986
1987
        transient_tstart={
            'multiplier': 1/86400.,
1988
1989
            'subtractor': 'minStartTime',
            'unit': 'day',
1990
1991
1992
            'label': 'Transient start-time \n days after minStartTime'}
            )

1993
    def _initiate_search_object(self):
Gregory Ashton's avatar
Gregory Ashton committed
1994
        logging.info('Setting up search object')
1995
        self.search = core.ComputeFstat(
Gregory Ashton's avatar
Gregory Ashton committed
1996
1997
1998
            tref=self.tref, sftfilepath=self.sftfilepath,
            minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq,
            earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem,
1999
            detectors=self.detectors, transient=True,
Gregory Ashton's avatar
Gregory Ashton committed
2000
            minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,