mcmc_based_searches.py 92 KB
Newer Older
1001
1002
            if ndim > 1:
                for i in range(ndim):
1003
                    axes[i].ticklabel_format(useOffset=False, axis='y')
Gregory Ashton's avatar
Gregory Ashton committed
1004
                    cs = chain[:, :, i].T
1005
                    if burnin_idx > 0:
1006
1007
                        axes[i].plot(xoffset+idxs[:convergence_idx+1],
                                     cs[:convergence_idx+1]-subtractions[i],
1008
                                     color="C3", alpha=alpha,
Gregory Ashton's avatar
Gregory Ashton committed
1009
                                     lw=lw)
1010
                        axes[i].axvline(xoffset+convergence_idx,
1011
                                        color='k', ls='--', lw=0.5)
1012
1013
                    axes[i].plot(xoffset+idxs[burnin_idx:],
                                 cs[burnin_idx:]-subtractions[i],
Gregory Ashton's avatar
Gregory Ashton committed
1014
                                 color="k", alpha=alpha, lw=lw)
Gregory Ashton's avatar
Gregory Ashton committed
1015
1016

                    axes[i].set_xlim(0, xoffset+idxs[-1])
1017
                    if symbols:
1018
                        if subtractions[i] == 0:
1019
                            axes[i].set_ylabel(symbols[i], labelpad=labelpad)
1020
1021
                        else:
                            axes[i].set_ylabel(
1022
1023
                                symbols[i]+'$-$'+symbols[i]+'$_0$',
                                labelpad=labelpad)
1024

1025
1026
                    if hasattr(self, 'convergence_diagnostic'):
                        ax = axes[i].twinx()
1027
1028
                        axes[i].set_zorder(ax.get_zorder()+1)
                        axes[i].patch.set_visible(False)
1029
1030
                        c_x = np.array(self.convergence_diagnosticx)
                        c_y = np.array(self.convergence_diagnostic)
1031
                        break_idx = np.argmin(np.abs(c_x - burnin_idx))
1032
1033
1034
1035
                        ax.plot(c_x[:break_idx], c_y[:break_idx, i], '-C0',
                                zorder=-10)
                        ax.plot(c_x[break_idx:], c_y[break_idx:, i], '-C0',
                                zorder=-10)
1036
1037
1038
1039
                        if self.convergence_test_type == 'autocorr':
                            ax.set_ylabel(r'$\tau_\mathrm{exp}$')
                        elif self.convergence_test_type == 'GR':
                            ax.set_ylabel('PSRF')
1040
                        ax.ticklabel_format(useOffset=False)
1041
            else:
Gregory Ashton's avatar
Gregory Ashton committed
1042
                axes[0].ticklabel_format(useOffset=False, axis='y')
Gregory Ashton's avatar
Gregory Ashton committed
1043
                cs = chain[:, :, temp].T
Gregory Ashton's avatar
Gregory Ashton committed
1044
1045
                if burnin_idx:
                    axes[0].plot(idxs[:burnin_idx], cs[:burnin_idx],
1046
                                 color="C3", alpha=alpha, lw=lw)
Gregory Ashton's avatar
Gregory Ashton committed
1047
1048
1049
                axes[0].plot(idxs[burnin_idx:], cs[burnin_idx:], color="k",
                             alpha=alpha, lw=lw)
                if symbols:
1050
                    axes[0].set_ylabel(symbols[0], labelpad=labelpad)
1051

Gregory Ashton's avatar
Gregory Ashton committed
1052
1053
            axes[-1].set_xlabel(r'$\textrm{Number of steps}$', labelpad=0.2)

1054
            if plot_det_stat:
1055
1056
1057
                if len(axes) == ndim:
                    axes.append(fig.add_subplot(ndim+1, 1, ndim+1))

1058
1059
1060
                lnl = sampler.lnlikelihood[temp, :, :]
                if burnin_idx and add_det_stat_burnin:
                    burn_in_vals = lnl[:, :burnin_idx].flatten()
1061
                    try:
1062
1063
1064
1065
                        twoF_burnin = (burn_in_vals[~np.isnan(burn_in_vals)]
                                       - self.likelihoodcoef)
                        axes[-1].hist(twoF_burnin, bins=50, histtype='step',
                                      color='C3')
1066
1067
1068
1069
                    except ValueError:
                        logging.info('Det. Stat. hist failed, most likely all '
                                     'values where the same')
                        pass
1070
                else:
1071
                    twoF_burnin = []
1072
                prod_vals = lnl[:, burnin_idx:].flatten()
1073
                try:
1074
1075
                    twoF = prod_vals[~np.isnan(prod_vals)]-self.likelihoodcoef
                    axes[-1].hist(twoF, bins=50, histtype='step', color='k')
1076
1077
1078
1079
                except ValueError:
                    logging.info('Det. Stat. hist failed, most likely all '
                                 'values where the same')
                    pass
1080
1081
1082
1083
1084
                if self.BSGL:
                    axes[-1].set_xlabel(r'$\mathcal{B}_\mathrm{S/GL}$')
                else:
                    axes[-1].set_xlabel(r'$\widetilde{2\mathcal{F}}$')
                axes[-1].set_ylabel(r'$\textrm{Counts}$')
1085
                combined_vals = np.append(twoF_burnin, twoF)
1086
1087
1088
1089
1090
1091
                if len(combined_vals) > 0:
                    minv = np.min(combined_vals)
                    maxv = np.max(combined_vals)
                    Range = abs(maxv-minv)
                    axes[-1].set_xlim(minv-0.1*Range, maxv+0.1*Range)

1092
                xfmt = matplotlib.ticker.ScalarFormatter()
1093
                xfmt.set_powerlimits((-4, 4))
1094
1095
                axes[-1].xaxis.set_major_formatter(xfmt)

1096
1097
        return fig, axes

1098
    def _apply_corrections_to_p0(self, p0):
Gregory Ashton's avatar
Gregory Ashton committed
1099
1100
1101
        """ Apply any correction to the initial p0 values """
        return p0

1102
    def _generate_scattered_p0(self, p):
1103
        """ Generate a set of p0s scattered about p """
Gregory Ashton's avatar
Gregory Ashton committed
1104
        p0 = [[p + self.scatter_val * p * np.random.randn(self.ndim)
1105
1106
1107
1108
               for i in xrange(self.nwalkers)]
              for j in xrange(self.ntemps)]
        return p0

1109
    def _generate_initial_p0(self):
1110
1111
1112
        """ Generate a set of init vals for the walkers """

        if type(self.theta_initial) == dict:
1113
            logging.info('Generate initial values from initial dictionary')
1114
            if hasattr(self, 'nglitch') and self.nglitch > 1:
1115
                raise ValueError('Initial dict not implemented for nglitch>1')
1116
            p0 = [[[self._generate_rv(**self.theta_initial[key])
1117
1118
1119
1120
                    for key in self.theta_keys]
                   for i in range(self.nwalkers)]
                  for j in range(self.ntemps)]
        elif self.theta_initial is None:
1121
            logging.info('Generate initial values from prior dictionary')
1122
            p0 = [[[self._generate_rv(**self.theta_prior[key])
1123
1124
1125
1126
1127
1128
1129
1130
                    for key in self.theta_keys]
                   for i in range(self.nwalkers)]
                  for j in range(self.ntemps)]
        else:
            raise ValueError('theta_initial not understood')

        return p0

1131
    def _get_new_p0(self, sampler):
        """ Returns new initial positions for walkers after the burn0 stage

        This returns new positions for all walkers by scattering points about
        the maximum posterior with scale `scatter_val`.

        """
        # Only the zero-temperature chain is used to locate the best point
        temp_idx = 0
        pF = sampler.chain[temp_idx, :, :, :]
        lnl = sampler.lnlikelihood[temp_idx, :, :]
        lnp = sampler.lnprobability[temp_idx, :, :]

        # General warnings about the state of lnp
        if np.any(np.isnan(lnp)):
            logging.warning(
                "Of {} lnprobs {} are nan".format(
                    np.shape(lnp), np.sum(np.isnan(lnp))))
        if np.any(np.isposinf(lnp)):
            logging.warning(
                "Of {} lnprobs {} are +np.inf".format(
                    np.shape(lnp), np.sum(np.isposinf(lnp))))
        if np.any(np.isneginf(lnp)):
            logging.warning(
                "Of {} lnprobs {} are -np.inf".format(
                    np.shape(lnp), np.sum(np.isneginf(lnp))))

        # Mask infinities as nan so nanargmax picks the best *finite* lnp
        lnp_finite = copy.copy(lnp)
        lnp_finite[np.isinf(lnp)] = np.nan
        # (walker, step) index of the maximum finite posterior
        idx = np.unravel_index(np.nanargmax(lnp_finite), lnp_finite.shape)
        p = pF[idx]
        p0 = self._generate_scattered_p0(p)

        # Temporarily disable BSGL so self.logl returns the plain twoF value
        # for logging; restore the user's setting afterwards
        self.search.BSGL = False
        twoF = self.logl(p, self.search)
        self.search.BSGL = self.BSGL

        logging.info(('Gen. new p0 from pos {} which had det. stat.={:2.1f},'
                      ' twoF={:2.1f} and lnp={:2.1f}')
                     .format(idx[1], lnl[idx], twoF, lnp_finite[idx]))

        return p0

1173
    def _get_data_dictionary_to_save(self):
1174
1175
        d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers,
                 ntemps=self.ntemps, theta_keys=self.theta_keys,
1176
                 theta_prior=self.theta_prior,
1177
                 log10beta_min=self.log10beta_min,
1178
1179
                 BSGL=self.BSGL, minStartTime=self.minStartTime,
                 maxStartTime=self.maxStartTime)
1180
1181
        return d

1182
    def _save_data(self, sampler, samples, lnprobs, lnlikes, all_lnlikelihood):
1183
        d = self._get_data_dictionary_to_save()
1184
1185
1186
        d['samples'] = samples
        d['lnprobs'] = lnprobs
        d['lnlikes'] = lnlikes
1187
        d['all_lnlikelihood'] = all_lnlikelihood
1188
1189
1190
1191
1192
1193
1194
1195

        if os.path.isfile(self.pickle_path):
            logging.info('Saving backup of {} as {}.old'.format(
                self.pickle_path, self.pickle_path))
            os.rename(self.pickle_path, self.pickle_path+".old")
        with open(self.pickle_path, "wb") as File:
            pickle.dump(d, File)

1196
    def get_saved_data_dictionary(self):
        """ Returns dictionary of the data saved in the pickle

        BUGFIX: the pickle is written in binary mode ("wb" in _save_data),
        so it must be read in binary mode too — text mode ("r") fails on
        Python 3 and can corrupt data on Windows.
        """
        with open(self.pickle_path, "rb") as File:
            d = pickle.load(File)
        return d

1202
    def _check_old_data_is_okay_to_use(self):
        """ Decide whether a previously pickled run can be reused.

        Returns True when the saved pickle matches the current settings
        (and is newer than the SFT data files), False otherwise. As a side
        effect, None-valued minStartTime/maxStartTime are filled in from
        the old run.
        """
        # NOTE(review): `args` is not defined in this method — presumably a
        # module-level parsed-arguments object; confirm it exists at import.
        if args.use_old_data:
            logging.info("Forcing use of old data")
            return True

        if os.path.isfile(self.pickle_path) is False:
            logging.info('No pickled data found')
            return False

        # If the pickle predates any matching SFT file, it is stale
        if self.sftfilepattern is not None:
            oldest_sft = min([os.path.getmtime(f) for f in
                              self._get_list_of_matching_sfts()])
            if os.path.getmtime(self.pickle_path) < oldest_sft:
                logging.info('Pickled data outdates sft files')
                return False

        old_d = self.get_saved_data_dictionary().copy()
        new_d = self._get_data_dictionary_to_save().copy()

        # Drop the chain products: only the run *settings* are compared
        old_d.pop('samples')
        old_d.pop('lnprobs')
        old_d.pop('lnlikes')
        old_d.pop('all_lnlikelihood')

        # Inherit unspecified time limits from the old run (also sets the
        # corresponding attributes on self)
        for key in 'minStartTime', 'maxStartTime':
            if new_d[key] is None:
                new_d[key] = old_d[key]
                setattr(self, key, new_d[key])

        # Collect (key, old, new) for every setting that differs
        mod_keys = []
        for key in new_d.keys():
            if key in old_d:
                if new_d[key] != old_d[key]:
                    mod_keys.append((key, old_d[key], new_d[key]))
            else:
                raise ValueError('Keys {} not in old dictionary'.format(key))

        if len(mod_keys) == 0:
            return True
        else:
            logging.warning("Saved data differs from requested")
            logging.info("Differences found in following keys:")
            for key in mod_keys:
                if len(key) == 3:
                    # Only scalar values (or nsteps) are worth printing fully
                    if np.isscalar(key[1]) or key[0] == 'nsteps':
                        logging.info("    {} : {} -> {}".format(*key))
                    else:
                        logging.info("    " + key[0])
                else:
                    logging.info(key)
            return False

    def get_max_twoF(self, threshold=0.05):
        """ Returns the max likelihood sample and the corresponding 2F value

        Note: the sample is returned as a dictionary along with an estimate of
        the standard deviation calculated from the std of all samples with a
        twoF within `threshold` (relative) to the max twoF

        NOTE(review): `threshold` is not used in the visible implementation.
        """
        if any(np.isposinf(self.lnlikes)):
            logging.info('lnlike values contain positive infinite values')
        if any(np.isneginf(self.lnlikes)):
            logging.info('lnlike values contain negative infinite values')
        if any(np.isnan(self.lnlikes)):
            logging.info('lnlike values contain nan')

        # BUGFIX: the previous code computed nanargmax over the *filtered*
        # array (self.lnlikes[idxs]) but then indexed the *unfiltered*
        # lnlikes/samples with that index, selecting the wrong sample
        # whenever any non-finite value preceded the maximum. Map the argmax
        # back to an index into the original arrays instead.
        finite_locs = np.flatnonzero(np.isfinite(self.lnlikes))
        jmax = finite_locs[np.argmax(self.lnlikes[finite_locs])]
        maxlogl = self.lnlikes[jmax]
        d = OrderedDict()

        if self.BSGL:
            # With BSGL the stored lnlikes are not twoF: recompute twoF at
            # the max-likelihood point with BSGL temporarily disabled
            if hasattr(self, 'search') is False:
                self._initiate_search_object()
            p = self.samples[jmax]
            self.search.BSGL = False
            maxtwoF = self.logl(p, self.search)
            self.search.BSGL = self.BSGL
        else:
            maxtwoF = (maxlogl - self.likelihoodcoef)*2

        # Repeated theta keys (e.g. per-glitch parameters) are relabelled
        # key_0, key_1, ... in order of appearance
        repeats = []
        for i, k in enumerate(self.theta_keys):
            if k in d and k not in repeats:
                d[k+'_0'] = d[k]  # relabel the old key
                d.pop(k)
                repeats.append(k)
            if k in repeats:
                k = k + '_0'
                count = 1
                while k in d:
                    k = k.replace('_{}'.format(count-1), '_{}'.format(count))
                    count += 1
            d[k] = self.samples[jmax][i]
        return d, maxtwoF

    def get_median_stds(self):
        """ Returns a dict of the median and std of all production samples """
        summary = OrderedDict()
        duplicated = []
        for column, name in zip(self.samples.T, self.theta_keys):
            if name in summary and name not in duplicated:
                # First repeat of this key: move the existing entries to
                # `<name>_0` / `<name>_0_std`
                summary[name + '_0'] = summary[name]
                summary[name + '_0_std'] = summary[name + '_std']
                summary.pop(name)
                summary.pop(name + '_std')
                duplicated.append(name)
            if name in duplicated:
                # Find the next free `<name>_<i>` label
                name = name + '_0'
                idx = 1
                while name in summary:
                    name = name.replace('_{}'.format(idx - 1),
                                        '_{}'.format(idx))
                    idx += 1
            summary[name] = np.median(column)
            summary[name + '_std'] = np.std(column)
        return summary

1320
    def check_if_samples_are_railing(self, threshold=0.01):
        """ Returns a boolean estimate of if the samples are railing

        Parameters
        ----------
        threshold: float [0, 1]
            Fraction of the uniform prior to test (at upper and lower bound)

        Returns
        -------
        return_flag: bool
            IF true, the samples are railing

        """
        railing = False
        for column, name in zip(self.samples.T, self.theta_keys):
            prior = self.theta_prior[name]
            if prior['type'] != 'unif':
                # Only uniform priors have hard edges to rail against
                continue
            span = prior['upper'] - prior['lower']
            edges = []
            fracs = []
            for bound in ['lower', 'upper']:
                near_edge = np.abs(column - prior[bound]) / span < threshold
                if np.any(near_edge):
                    edges.append(bound)
                    fracs.append(
                        str(100*float(np.sum(near_edge))/len(near_edge)))
            if len(edges) > 0:
                logging.warning(
                    '{}% of the {} posterior is railing on the {} edges'
                    .format('% & '.join(fracs), name, ' & '.join(edges)))
                railing = True
        return railing

1353
1354
1355
1356
    def write_par(self, method='med'):
        """ Writes a .par of the best-fit params with an estimated std

        Parameters
        ----------
        method: str, 'med' or 'twoFmax'
            Whether to write the median values or the values at the maximum
            twoF sample.

        Fixes: use `items()` instead of the Python-2-only `iteritems()`, and
        test `hasattr(self, 'theta0_idx')` — the attribute actually read
        below (and used by print_summary); the old check on 'theta0_index'
        never fired.
        """
        logging.info('Writing {}/{}.par using the {} method'.format(
            self.outdir, self.label, method))

        median_std_d = self.get_median_stds()
        max_twoF_d, max_twoF = self.get_max_twoF()

        logging.info('Writing par file with max twoF = {}'.format(max_twoF))
        filename = '{}/{}.par'.format(self.outdir, self.label)
        with open(filename, 'w+') as f:
            f.write('MaxtwoF = {}\n'.format(max_twoF))
            f.write('tref = {}\n'.format(self.tref))
            if hasattr(self, 'theta0_idx'):
                f.write('theta0_index = {}\n'.format(self.theta0_idx))
            if method == 'med':
                for key, val in median_std_d.items():
                    f.write('{} = {:1.16e}\n'.format(key, val))
            if method == 'twoFmax':
                for key, val in max_twoF_d.items():
                    f.write('{} = {:1.16e}\n'.format(key, val))

1375
    def generate_loudest(self):
        """ Use lalapps_ComputeFstatistic_v2 to produce a .loudest file """
        self.write_par()
        params = read_par(label=self.label, outdir=self.outdir)
        # Fall back to the prior entries for parameters missing from the .par
        for name in ['Alpha', 'Delta', 'F0', 'F1']:
            if name not in params:
                params[name] = self.theta_prior[name]
        # NOTE(review): shell=True with interpolated paths — acceptable for
        # trusted local use, but do not expose to untrusted input.
        cmd = ('lalapps_ComputeFstatistic_v2 -a {} -d {} -f {} -s {} -D "{}"'
               ' --refTime={} --outputLoudest="{}/{}.loudest" '
               '--minStartTime={} --maxStartTime={}').format(
                   params['Alpha'], params['Delta'], params['F0'],
                   params['F1'], self.sftfilepattern, params['tref'],
                   self.outdir, self.label, self.minStartTime,
                   self.maxStartTime)
        subprocess.call([cmd], shell=True)

Gregory Ashton's avatar
Gregory Ashton committed
1391
    def write_prior_table(self):
        """ Generate a .tex file of the prior

        Writes one tabular row per dict-type prior, with the symbol, the
        distribution (Unif / Normal / half-Normal) and the unit.

        Fix: use `items()` instead of the Python-2-only `iteritems()`.
        """
        with open('{}/{}_prior.tex'.format(self.outdir, self.label), 'w') as f:
            f.write(r"\begin{tabular}{c l c} \hline" + '\n'
                    r"Parameter & & &  \\ \hhline{====}")

            for key, prior in self.theta_prior.items():
                if type(prior) is dict:
                    Type = prior['type']
                    if Type == "unif":
                        a = prior['lower']
                        b = prior['upper']
                        line = r"{} & $\mathrm{{Unif}}$({}, {}) & {}\\"
                    elif Type == "norm":
                        a = prior['loc']
                        b = prior['scale']
                        line = r"{} & $\mathcal{{N}}$({}, {}) & {}\\"
                    elif Type == "halfnorm":
                        a = prior['loc']
                        b = prior['scale']
                        line = r"{} & $|\mathcal{{N}}$({}, {})| & {}\\"
                    # NOTE(review): an unrecognised prior type would leave
                    # a/b/line unbound (NameError) — confirm all types used
                    # in theta_prior are covered above.

                    u = self.unit_dictionary[key]
                    s = self.symbol_dictionary[key]
                    f.write("\n")
                    a = helper_functions.texify_float(a)
                    b = helper_functions.texify_float(b)
                    f.write(" " + line.format(s, a, b, u) + r" \\")
            f.write("\n\\end{tabular}\n")

1421
    def print_summary(self):
        """ Prints a summary of the max twoF found to the terminal

        Fix: use the builtin `sorted()` on the key views — `np.sort` on a
        Python-3 `dict_keys` view fails (it is wrapped as a 0-d object
        array); `sorted()` works on both Python 2 and 3.
        """
        max_twoFd, max_twoF = self.get_max_twoF()
        median_std_d = self.get_median_stds()
        logging.info('Summary:')
        if hasattr(self, 'theta0_idx'):
            logging.info('theta0 index: {}'.format(self.theta0_idx))
        logging.info('Max twoF: {} with parameters:'.format(max_twoF))
        for k in sorted(max_twoFd.keys()):
            print('  {:10s} = {:1.9e}'.format(k, max_twoFd[k]))
        logging.info('Median +/- std for production values')
        for k in sorted(median_std_d.keys()):
            if 'std' not in k:
                logging.info('  {:10s} = {:1.9e} +/- {:1.9e}'.format(
                    k, median_std_d[k], median_std_d[k+'_std']))
        logging.info('\n')
1437

1438
    def _CF_twoFmax(self, theta, twoFmax, ntrials):
1439
1440
1441
1442
        Fmax = twoFmax/2.0
        return (np.exp(1j*theta*twoFmax)*ntrials/2.0
                * Fmax*np.exp(-Fmax)*(1-(1+Fmax)*np.exp(-Fmax))**(ntrials-1))

1443
    def _pdf_twoFhat(self, twoFhat, nglitch, ntrials, twoFmax=100, dtwoF=0.1):
1444
1445
1446
1447
1448
        if np.ndim(ntrials) == 0:
            ntrials = np.zeros(nglitch+1) + ntrials
        twoFmax_int = np.arange(0, twoFmax, dtwoF)
        theta_int = np.arange(-1/dtwoF, 1./dtwoF, 1./twoFmax)
        CF_twoFmax_theta = np.array(
1449
            [[np.trapz(self._CF_twoFmax(t, twoFmax_int, ntrial), twoFmax_int)
1450
1451
1452
1453
1454
1455
1456
1457
              for t in theta_int]
             for ntrial in ntrials])
        CF_twoFhat_theta = np.prod(CF_twoFmax_theta, axis=0)
        pdf = (1/(2*np.pi)) * np.array(
            [np.trapz(np.exp(-1j*theta_int*twoFhat_val)
             * CF_twoFhat_theta, theta_int) for twoFhat_val in twoFhat])
        return pdf.real

1458
    def _p_val_twoFhat(self, twoFhat, ntrials, twoFhatmax=500, Npoints=1000):
1459
        """ Caluculate the p-value for the given twoFhat in Gaussian noise
1460
1461
1462
1463
1464
1465
1466
1467
1468

        Parameters
        ----------
        twoFhat: float
            The observed twoFhat value
        ntrials: int, array of len Nglitch+1
            The number of trials for each glitch+1
        """
        twoFhats = np.linspace(twoFhat, twoFhatmax, Npoints)
1469
        pdf = self._pdf_twoFhat(twoFhats, self.nglitch, ntrials)
1470
1471
1472
1473
1474
1475
1476
1477
        return np.trapz(pdf, twoFhats)

    def get_p_value(self, delta_F0, time_trials=0):
        """ Gets the p-value for the maximum twoFhat value """
        d, max_twoF = self.get_max_twoF()
        if self.nglitch == 1:
            glitch_times = [d['tglitch']]
        else:
            glitch_times = [d['tglitch_{}'.format(i)]
                            for i in range(self.nglitch)]
        # Segment boundaries: data start, each glitch, data end
        boundaries = [self.minStartTime] + glitch_times + [self.maxStartTime]
        durations = np.diff(boundaries)
        trials_per_segment = [time_trials + delta_F0 * dT
                              for dT in durations]
        p_val = self._p_val_twoFhat(max_twoF, trials_per_segment)
        print('p-value = {}'.format(p_val))
        return p_val

Gregory Ashton's avatar
Gregory Ashton committed
1487
    def compute_evidence(self, write_to_file='Evidences.txt'):
        """ Computes the evidence/marginal likelihood for the model

        Thermodynamic integration of the mean log-likelihood over the
        inverse temperatures `betas`; the error estimate compares the full
        integral with one using every other beta. Also writes a diagnostic
        plot and (optionally) appends the result to `write_to_file`.

        BUGFIX: `len(betas)/2` is a float on Python 3 and `range()` then
        raises TypeError — use floor division.
        """
        betas = self.betas
        mean_lnlikes = np.mean(np.mean(self.all_lnlikelihood, axis=1), axis=1)

        # Order from beta=0 towards beta=1 for the integration
        mean_lnlikes = mean_lnlikes[::-1]
        betas = betas[::-1]

        fig, (ax1, ax2) = plt.subplots(nrows=2, figsize=(6, 8))

        if any(np.isinf(mean_lnlikes)):
            print("WARNING mean_lnlikes contains inf: recalculating without"
                  " the {} infs".format(len(betas[np.isinf(mean_lnlikes)])))
            idxs = np.isinf(mean_lnlikes)
            mean_lnlikes = mean_lnlikes[~idxs]
            betas = betas[~idxs]

        log10evidence = np.trapz(mean_lnlikes, betas)/np.log(10)
        # Error estimate: difference between full and half-resolution trapz
        z1 = np.trapz(mean_lnlikes, betas)
        z2 = np.trapz(mean_lnlikes[::-1][::2][::-1],
                      betas[::-1][::2][::-1])
        log10evidence_err = np.abs(z1 - z2) / np.log(10)

        logging.info("log10 evidence for {} = {} +/- {}".format(
              self.label, log10evidence, log10evidence_err))

        if write_to_file:
            EvidenceDict = self.read_evidence_file_to_dict(write_to_file)
            EvidenceDict[self.label] = [log10evidence, log10evidence_err]
            self.write_evidence_file_from_dict(EvidenceDict, write_to_file)

        ax1.semilogx(betas, mean_lnlikes, "-o")
        ax1.set_xlabel(r"$\beta$")
        ax1.set_ylabel(r"$\langle \log(\mathcal{L}) \rangle$")
        min_betas = []
        evidence = []
        # BUGFIX: floor division so range() gets an int on Python 3
        for i in range(len(betas)//2):
            min_betas.append(betas[i])
            lnZ = np.trapz(mean_lnlikes[i:], betas[i:])
            evidence.append(lnZ/np.log(10))

        ax2.semilogx(min_betas, evidence, "-o")
        ax2.set_ylabel(r"$\int_{\beta_{\textrm{Min}}}^{\beta=1}" +
                       r"\langle \log(\mathcal{L})\rangle d\beta$", size=16)
        ax2.set_xlabel(r"$\beta_{\textrm{min}}$")
        plt.tight_layout()
        fig.savefig("{}/{}_beta_lnl.png".format(self.outdir, self.label))

Gregory Ashton's avatar
Gregory Ashton committed
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
    @staticmethod
    def read_evidence_file_to_dict(evidence_file_name='Evidences.txt'):
        """ Read 'label log10evidence log10evidence_err' lines into a dict.

        Returns an empty OrderedDict when the file does not exist.
        """
        evidences = OrderedDict()
        if not os.path.isfile(evidence_file_name):
            return evidences
        with open(evidence_file_name, 'r') as fh:
            for row in fh:
                label, log10evidence, log10evidence_err = row.split(' ')
                evidences[label] = [float(log10evidence),
                                    float(log10evidence_err)]
        return evidences

    def write_evidence_file_from_dict(self, EvidenceDict, evidence_file_name):
        """ Write each label's [log10evidence, log10evidence_err] to file.

        Fix: use `items()` instead of the Python-2-only `iteritems()`
        (identical iteration on Python 2, works on Python 3).
        """
        with open(evidence_file_name, 'w+') as f:
            for key, val in EvidenceDict.items():
                f.write('{} {} {}\n'.format(key, val[0], val[1]))

1551

Gregory Ashton's avatar
Gregory Ashton committed
1552
class MCMCGlitchSearch(MCMCSearch):
    """MCMC search using the SemiCoherentGlitchSearch

    See parent MCMCSearch for a list of all additional parameters, here we list
    only the additional init parameters of this class.

    Parameters
    ----------
    nglitch: int
        The number of glitches to allow
    dtglitchmin: int
        The minimum duration (in seconds) of a segment between two glitches
        or a glitch and the start/end of the data
    theta0_idx, int
        Index (zero-based) of which segment the theta refers to - useful
        if providing a tight prior on theta to allow the signal to jump
        to theta (and not just from)

    """

    symbol_dictionary = dict(
        F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', Alpha=r'$\alpha$',
        Delta='$\delta$', delta_F0='$\delta f$',
        delta_F1='$\delta \dot{f}$', tglitch='$t_\mathrm{glitch}$')
    unit_dictionary = dict(
        F0='Hz', F1='Hz/s', F2='Hz/s$^2$', Alpha=r'rad', Delta='rad',
        delta_F0='Hz', delta_F1='Hz/s', tglitch='s')
    transform_dictionary = dict(
        tglitch={
            'multiplier': 1/86400.,
            'subtractor': 'minStartTime',
            'unit': 'day',
            'label': 'Glitch time \n days after minStartTime'}
            )

    @helper_functions.initializer
    def __init__(self, theta_prior, tref, label, outdir='data',
                 minStartTime=None, maxStartTime=None, sftfilepattern=None,
                 detectors=None, nsteps=[100, 100], nwalkers=100, ntemps=1,
                 log10beta_min=-5, theta_initial=None,
                 rhohatmax=1000, binary=False, BSGL=False,
                 SSBprec=None, minCoverFreq=None, maxCoverFreq=None,
                 injectSources=None, assumeSqrtSX=None,
                 dtglitchmin=1*86400, theta0_idx=0, nglitch=1):

        if not os.path.isdir(outdir):
            os.mkdir(outdir)
        self._add_log_file()
        logging.info(('Set-up MCMC glitch search with {} glitches for model {}'
                      ' on data {}').format(self.nglitch, self.label,
                                            self.sftfilepattern))
        self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
        self._unpack_input_theta()
        self.ndim = len(self.theta_keys)
        # NOTE(review): log10beta_min == 0 is treated like None (no
        # temperature ladder) by this truthiness test -- confirm intended
        if self.log10beta_min:
            self.betas = np.logspace(0, self.log10beta_min, self.ntemps)
        else:
            self.betas = None
        if args.clean and os.path.isfile(self.pickle_path):
            os.rename(self.pickle_path, self.pickle_path+".old")

        self.old_data_is_okay_to_use = self._check_old_data_is_okay_to_use()
        self._log_input()
        self._set_likelihoodcoef()

    def _set_likelihoodcoef(self):
        # One normalisation term per coherent segment: nglitch glitches
        # split the data into nglitch+1 segments
        self.likelihoodcoef = (self.nglitch+1)*np.log(70./self.rhohatmax**4)

    def _initiate_search_object(self):
        logging.info('Setting up search object')
        self.search = core.SemiCoherentGlitchSearch(
            label=self.label, outdir=self.outdir,
            sftfilepattern=self.sftfilepattern, tref=self.tref,
            minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
            minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq,
            detectors=self.detectors, BSGL=self.BSGL, nglitch=self.nglitch,
            theta0_idx=self.theta0_idx, injectSources=self.injectSources)
        # Fall back to the data span reported by the search object
        if self.minStartTime is None:
            self.minStartTime = self.search.minStartTime
        if self.maxStartTime is None:
            self.maxStartTime = self.search.maxStartTime

    def logp(self, theta_vals, theta_prior, theta_keys, search):
        """Log-prior: -inf for unordered or too-closely-spaced glitch
        times, otherwise the sum of the individual parameter log-priors."""
        if self.nglitch > 1:
            ts = ([self.minStartTime] + list(theta_vals[-self.nglitch:])
                  + [self.maxStartTime])
            # Reject unsorted glitch times and segments < dtglitchmin
            if not np.array_equal(ts, np.sort(ts)):
                return -np.inf
            if any(np.diff(ts) < self.dtglitchmin):
                return -np.inf

        H = [self._generic_lnprior(**theta_prior[key])(p) for p, key in
             zip(theta_vals, theta_keys)]
        return np.sum(H)

    def logl(self, theta, search):
        """Log-likelihood from the semi-coherent glitch twoF statistic."""
        if self.nglitch > 1:
            ts = ([self.minStartTime] + list(theta[-self.nglitch:])
                  + [self.maxStartTime])
            if not np.array_equal(ts, np.sort(ts)):
                return -np.inf

        for j, theta_i in enumerate(self.theta_idxs):
            self.fixed_theta[theta_i] = theta[j]
        twoF = search.get_semicoherent_nglitch_twoF(*self.fixed_theta)
        return twoF/2.0 + self.likelihoodcoef

    def _unpack_input_theta(self):
        """Map theta_prior into ordered keys/symbols/indices, expanding
        the per-glitch parameters nglitch times each."""
        glitch_keys = ['delta_F0', 'delta_F1', 'tglitch']
        full_glitch_keys = list(np.array(
            [[gk]*self.nglitch for gk in glitch_keys]).flatten())

        if 'tglitch_0' in self.theta_prior:
            # Per-glitch priors given individually: rename the repeated
            # keys to their indexed forms. The delta_F0 slice start
            # under-runs the list for nglitch > 1; Python clamps it to 0,
            # which is the correct delta_F0 region.
            full_glitch_keys[-self.nglitch:] = [
                'tglitch_{}'.format(i) for i in range(self.nglitch)]
            full_glitch_keys[-2*self.nglitch:-1*self.nglitch] = [
                'delta_F1_{}'.format(i) for i in range(self.nglitch)]
            full_glitch_keys[-4*self.nglitch:-2*self.nglitch] = [
                'delta_F0_{}'.format(i) for i in range(self.nglitch)]
        full_theta_keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta']+full_glitch_keys
        full_theta_keys_copy = copy.copy(full_theta_keys)

        glitch_symbols = ['$\delta f$', '$\delta \dot{f}$', r'$t_{glitch}$']
        full_glitch_symbols = list(np.array(
            [[gs]*self.nglitch for gs in glitch_symbols]).flatten())
        full_theta_symbols = (['$f$', '$\dot{f}$', '$\ddot{f}$', r'$\alpha$',
                               r'$\delta$'] + full_glitch_symbols)
        self.theta_keys = []
        fixed_theta_dict = {}
        # items() rather than the Python-2-only iteritems()
        for key, val in self.theta_prior.items():
            if type(val) is dict:
                fixed_theta_dict[key] = 0
                if key in glitch_keys:
                    for i in range(self.nglitch):
                        self.theta_keys.append(key)
                else:
                    self.theta_keys.append(key)
            elif type(val) in [float, int, np.float64]:
                fixed_theta_dict[key] = val
            else:
                raise ValueError(
                    'Type {} of {} in theta not recognised'.format(
                        type(val), key))
            if key in glitch_keys:
                for i in range(self.nglitch):
                    full_theta_keys_copy.pop(full_theta_keys_copy.index(key))
            else:
                full_theta_keys_copy.pop(full_theta_keys_copy.index(key))

        if len(full_theta_keys_copy) > 0:
            raise ValueError(('Input dictionary `theta` is missing the'
                              'following keys: {}').format(
                                  full_theta_keys_copy))

        self.fixed_theta = [fixed_theta_dict[key] for key in full_theta_keys]
        self.theta_idxs = [full_theta_keys.index(k) for k in self.theta_keys]
        self.theta_symbols = [full_theta_symbols[i] for i in self.theta_idxs]

        idxs = np.argsort(self.theta_idxs)
        self.theta_idxs = [self.theta_idxs[i] for i in idxs]
        self.theta_symbols = [self.theta_symbols[i] for i in idxs]
        self.theta_keys = [self.theta_keys[i] for i in idxs]

        # Correct for number of glitches in the idxs: repeated glitch keys
        # initially map to the same index; bump duplicates until unique
        self.theta_idxs = np.array(self.theta_idxs)
        while np.sum(self.theta_idxs[:-1] == self.theta_idxs[1:]) > 0:
            for i, idx in enumerate(self.theta_idxs):
                if idx in self.theta_idxs[:i]:
                    self.theta_idxs[i] += 1

    def _get_data_dictionary_to_save(self):
        """Settings dictionary stored alongside the pickled sampler data."""
        d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers,
                 ntemps=self.ntemps, theta_keys=self.theta_keys,
                 theta_prior=self.theta_prior,
                 log10beta_min=self.log10beta_min,
                 theta0_idx=self.theta0_idx, BSGL=self.BSGL)
        return d

    def _apply_corrections_to_p0(self, p0):
        """Sort each walker's glitch-time components so the initial
        positions satisfy the ordering required by logp/logl."""
        p0 = np.array(p0)
        if self.nglitch > 1:
            p0[:, :, -self.nglitch:] = np.sort(p0[:, :, -self.nglitch:],
                                               axis=2)
        return p0

    def plot_cumulative_max(self):
        """Plot cumulative twoF for each inter-glitch segment at the
        maximum-twoF parameters, saved as <label>_twoFcumulative.png."""
        fig, ax = plt.subplots()
        d, maxtwoF = self.get_max_twoF()
        # items() rather than the Python-2-only iteritems()
        for key, val in self.theta_prior.items():
            if key not in d:
                d[key] = val

        if self.nglitch > 1:
            delta_F0s = [d['delta_F0_{}'.format(i)] for i in
                         range(self.nglitch)]
            delta_F0s.insert(self.theta0_idx, 0)
            delta_F0s = np.array(delta_F0s)
            # Segments before theta0_idx are reached by stepping backwards
            delta_F0s[:self.theta0_idx] *= -1
            tglitches = [d['tglitch_{}'.format(i)] for i in
                         range(self.nglitch)]
        elif self.nglitch == 1:
            delta_F0s = [d['delta_F0']]
            delta_F0s.insert(self.theta0_idx, 0)
            delta_F0s = np.array(delta_F0s)
            delta_F0s[:self.theta0_idx] *= -1
            tglitches = [d['tglitch']]

        tboundaries = [self.minStartTime] + tglitches + [self.maxStartTime]

        for j in range(self.nglitch+1):
            ts = tboundaries[j]
            te = tboundaries[j+1]
            if (te - ts)/86400 < 5:
                logging.info('Period too short to perform cumulative search')
                continue
            if j < self.theta0_idx:
                summed_deltaF0 = np.sum(delta_F0s[j:self.theta0_idx])
                F0_j = d['F0'] - summed_deltaF0
                taus, twoFs = self.search.calculate_twoF_cumulative(
                    F0_j, F1=d['F1'], F2=d['F2'], Alpha=d['Alpha'],
                    Delta=d['Delta'], tstart=ts, tend=te)

            elif j >= self.theta0_idx:
                summed_deltaF0 = np.sum(delta_F0s[self.theta0_idx:j+1])
                F0_j = d['F0'] + summed_deltaF0
                taus, twoFs = self.search.calculate_twoF_cumulative(
                    F0_j, F1=d['F1'], F2=d['F2'], Alpha=d['Alpha'],
                    Delta=d['Delta'], tstart=ts, tend=te)
            ax.plot(ts+taus, twoFs)

        ax.set_xlabel('GPS time')
        fig.savefig('{}/{}_twoFcumulative.png'.format(self.outdir, self.label))

Gregory Ashton's avatar
Gregory Ashton committed
1786

1787
class MCMCSemiCoherentSearch(MCMCSearch):
    """ MCMC search for a signal using the semi-coherent ComputeFstat

    See parent MCMCSearch for a list of all additional parameters, here we list
    only the additional init parameters of this class.

    Parameters
    ----------
    nsegs: int
        The number of segments

    """

    @helper_functions.initializer
    def __init__(self, theta_prior, tref, label, outdir='data',
                 minStartTime=None, maxStartTime=None, sftfilepattern=None,
                 detectors=None, nsteps=[100, 100], nwalkers=100, ntemps=1,
                 log10beta_min=-5, theta_initial=None,
                 rhohatmax=1000, binary=False, BSGL=False,
                 SSBprec=None, minCoverFreq=None, maxCoverFreq=None,
                 injectSources=None, assumeSqrtSX=None,
                 nsegs=None):

        if not os.path.isdir(outdir):
            os.mkdir(outdir)
        self._add_log_file()
        # Trailing space added so the label and data pattern do not run
        # together in the log message
        logging.info(('Set-up MCMC semi-coherent search for model {} on data '
                      '{}').format(
            self.label, self.sftfilepattern))
        self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
        self._unpack_input_theta()
        self.ndim = len(self.theta_keys)
        # NOTE(review): log10beta_min == 0 is treated like None (no
        # temperature ladder) by this truthiness test -- confirm intended
        if self.log10beta_min:
            self.betas = np.logspace(0, self.log10beta_min, self.ntemps)
        else:
            self.betas = None
        if args.clean and os.path.isfile(self.pickle_path):
            os.rename(self.pickle_path, self.pickle_path+".old")

        self._log_input()

        # The likelihood coefficient depends on nsegs, which subclasses
        # (e.g. the follow-up search) may only provide later
        if self.nsegs:
            self._set_likelihoodcoef()
        else:
            logging.info('Value `nsegs` not yet provided')

    def _set_likelihoodcoef(self):
        # One normalisation term per coherent segment
        self.likelihoodcoef = self.nsegs * np.log(70./self.rhohatmax**4)

    def _get_data_dictionary_to_save(self):
        """Settings dictionary stored alongside the pickled sampler data."""
        d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers,
                 ntemps=self.ntemps, theta_keys=self.theta_keys,
                 theta_prior=self.theta_prior,
                 log10beta_min=self.log10beta_min,
                 BSGL=self.BSGL, nsegs=self.nsegs)
        return d

    def _initiate_search_object(self):
        logging.info('Setting up search object')
        self.search = core.SemiCoherentSearch(
            label=self.label, outdir=self.outdir, tref=self.tref,
            nsegs=self.nsegs, sftfilepattern=self.sftfilepattern,
            binary=self.binary, BSGL=self.BSGL, minStartTime=self.minStartTime,
            maxStartTime=self.maxStartTime, minCoverFreq=self.minCoverFreq,
            maxCoverFreq=self.maxCoverFreq, detectors=self.detectors,
            injectSources=self.injectSources, assumeSqrtSX=self.assumeSqrtSX)
        # Fall back to the data span reported by the search object
        if self.minStartTime is None:
            self.minStartTime = self.search.minStartTime
        if self.maxStartTime is None:
            self.maxStartTime = self.search.maxStartTime

    def logp(self, theta_vals, theta_prior, theta_keys, search):
        """Log-prior: sum of the individual parameter log-priors."""
        H = [self._generic_lnprior(**theta_prior[key])(p) for p, key in
             zip(theta_vals, theta_keys)]
        return np.sum(H)

    def logl(self, theta, search):
        """Log-likelihood from the semi-coherent twoF statistic."""
        for j, theta_i in enumerate(self.theta_idxs):
            self.fixed_theta[theta_i] = theta[j]
        twoF = search.get_semicoherent_twoF(
            *self.fixed_theta)
        return twoF/2.0 + self.likelihoodcoef
1869
1870


Gregory Ashton's avatar
Gregory Ashton committed
1871
class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
Gregory Ashton's avatar
Gregory Ashton committed
1872
1873
1874
1875
1876
1877
    """ A follow up procudure increasing the coherence time in a zoom

    See parent MCMCSemiCoherentSearch for a list of all additional parameters

    """

1878
    def _get_data_dictionary_to_save(self):
Gregory Ashton's avatar
Gregory Ashton committed
1879
1880
        d = dict(nwalkers=self.nwalkers, ntemps=self.ntemps,
                 theta_keys=self.theta_keys, theta_prior=self.theta_prior,
1881
                 log10beta_min=self.log10beta_min,
Gregory Ashton's avatar
Gregory Ashton committed
1882
1883
1884
                 BSGL=self.BSGL, run_setup=self.run_setup)
        return d

Gregory Ashton's avatar
Gregory Ashton committed
1885
1886
1887
1888
    def update_search_object(self):
        """Re-initialise the F-statistic set-up of `self.search` by calling
        its `init_computefstatistic_single_point` method."""
        logging.info('Update search object')
        self.search.init_computefstatistic_single_point()

1889
1890
1891
1892
1893
1894
1895
1896
    def get_width_from_prior(self, prior, key):
        """Return the width (upper - lower) of a uniform prior entry.

        Non-'unif' prior types fall through and implicitly return None.
        """
        spec = prior[key]
        if spec['type'] == 'unif':
            return spec['upper'] - spec['lower']

    def get_mid_from_prior(self, prior, key):
        """Return the midpoint of a uniform prior entry.

        Non-'unif' prior types fall through and implicitly return None.
        """
        spec = prior[key]
        if spec['type'] == 'unif':
            return 0.5 * (spec['upper'] + spec['lower'])

1897
1898
1899
1900
1901
    def read_setup_input_file(self, run_setup_input_file):
        """Load a pickled run-setup dictionary from file.

        Parameters
        ----------
        run_setup_input_file: str
            Path to the pickle file written by write_setup_input_file.

        Returns
        -------
        d: dict
            The stored setup dictionary.
        """
        # pickle requires binary mode; the previous 'r+' text mode fails
        # on Python 3
        with open(run_setup_input_file, 'rb') as f:
            d = pickle.load(f)
        return d

Gregory Ashton's avatar
Gregory Ashton committed
1902
    def write_setup_input_file(self, run_setup_input_file, NstarMax, Nsegs0,
                               nsegs_vals, Nstar_vals, theta_prior):
        """Pickle the follow-up run-setup to file for later re-use.

        Parameters
        ----------
        run_setup_input_file: str
            Path to write the pickle to (overwritten if it exists).
        NstarMax, Nsegs0: int
            Setup parameters that check_old_run_setup compares against.
        nsegs_vals, Nstar_vals: list
            Per-stage segment counts and template-count estimates.
        theta_prior: dict
            The prior dictionary the setup was computed for.
        """
        d = dict(NstarMax=NstarMax, Nsegs0=Nsegs0, nsegs_vals=nsegs_vals,
                 theta_prior=theta_prior, Nstar_vals=Nstar_vals)
        # pickle requires binary mode; the previous 'w+' text mode fails
        # on Python 3
        with open(run_setup_input_file, 'wb') as f:
            pickle.dump(d, f)

1909
1910
    def check_old_run_setup(self, old_setup, **kwargs):
        """Check whether a previously saved setup matches the requested one.

        Parameters
        ----------
        old_setup: dict
            Setup dictionary loaded from a previous run.
        **kwargs:
            Key/value pairs to compare against old_setup entries.

        Returns
        -------
        bool:
            True if every kwarg equals the corresponding old_setup entry;
            False on any mismatch or missing key.
        """
        try:
            # items() rather than the Python-2-only iteritems()
            truths = [val == old_setup[key] for key, val in kwargs.items()]
            if all(truths):
                return True
            else:
                logging.info(
                    "Old setup doesn't match one of NstarMax, Nsegs0 or prior")
                # Previously fell through and returned None; make the
                # mismatch result explicit
                return False
        except KeyError as e:
            logging.info(
                'Error found when comparing with old setup: {}'.format(e))
            return False

Gregory Ashton's avatar
Gregory Ashton committed
1922
1923
    def init_run_setup(self, run_setup=None, NstarMax=1000, Nsegs0=None,
                       log_table=True, gen_tex_table=True):
1924
1925
1926

        if run_setup is None and Nsegs0 is None:
            raise ValueError(
Gregory Ashton's avatar
Gregory Ashton committed
1927
1928
                'You must either specify the run_setup, or Nsegs0 and NStarMax'
                ' from which the optimal run_setup can be estimated')
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
        if run_setup is None:
            logging.info('No run_setup provided')

            run_setup_input_file = '{}/{}_run_setup.p'.format(
                self.outdir, self.label)

            if os.path.isfile(run_setup_input_file):
                logging.info('Checking old setup input file {}'.format(
                    run_setup_input_file))
                old_setup = self.read_setup_input_file(run_setup_input_file)
Gregory Ashton's avatar
Gregory Ashton committed
1939
                if self.check_old_run_setup(old_setup, NstarMax=NstarMax,
1940
                                            Nsegs0=Nsegs0,
1941
                                            theta_prior=self.theta_prior):
Gregory Ashton's avatar
Gregory Ashton committed
1942
1943
1944
                    logging.info(
                        'Using old setup with NstarMax={}, Nsegs0={}'.format(
                            NstarMax, Nsegs0))
1945
                    nsegs_vals = old_setup['nsegs_vals']
1946
                    Nstar_vals = old_setup['Nstar_vals']
1947
                    generate_setup = False
1948
                else:
1949
1950
1951
                    generate_setup = True
            else:
                generate_setup = True
1952

1953
            if generate_setup:
1954
1955
1956
1957
1958
1959
1960
                nsegs_vals, Nstar_vals = (
                        optimal_setup_functions.get_optimal_setup(
                            NstarMax, Nsegs0, self.tref, self.minStartTime,
                            self.maxStartTime, self.theta_prior,
                            self.search.detector_names))
                self.write_setup_input_file(run_setup_input_file, NstarMax,
                                            Nsegs0, nsegs_vals, Nstar_vals,
1961
                                            self.theta_prior)
1962
1963
1964
1965
1966

            run_setup = [((self.nsteps[0], 0),  nsegs, False)
                         for nsegs in nsegs_vals[:-1]]
            run_setup.append(
                ((self.nsteps[0], self.nsteps[1]), nsegs_vals[-1], False))
1967
1968
1969

        else:
            logging.info('Calculating the number of templates for this setup')
1970
            Nstar_vals = []
1971
1972
1973
1974
1975
1976
1977
1978
            for i, rs in enumerate(run_setup):
                rs = list(rs)
                if len(rs) == 2:
                    rs.append(False)
                if np.shape(rs[0]) == ():
                    rs[0] = (rs[0], 0)
                run_setup[i] = rs

1979
                if args.no_template_counting:
1980
                    Nstar_vals.append([1, 1, 1])
1981
                else:
1982
                    Nstar = optimal_setup_functions.get_Nstar_estimate(
1983
                        rs[1], self.tref, self.minStartTime, self.maxStartTime,
1984
                        self.theta_prior, self.search.detector_names)
1985
                    Nstar_vals.append(Nstar)
1986
1987

        if log_table:
1988
            logging.info('Using run-setup as follows:')
1989
            logging.info(
1990
                'Stage | nburn | nprod | nsegs | Tcoh d | resetp0 | Nstar')
1991
            for i, rs in enumerate(run_setup):
1992
                Tcoh = (self.maxStartTime - self.minStartTime) / rs[1] / 86400
1993
                if Nstar_vals[i] is None:
1994
                    vtext = 'N/A'