Commit 5fe5f787 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Various improvements to the code

- Add peak detection utilities
- Move generate_loudest
- Add SSBprec pass through to MCMC searches
- Fix convergence stat (missing sqrt)
- Minor fixes to plot_walkers (put detection statistic behind walkers)
parent 55fc8889
......@@ -164,20 +164,6 @@ class BaseSearchClass(object):
self.thetas_at_tref = thetas
return thetas
def generate_loudest(self):
params = read_par(self.label, self.outdir)
for key in ['Alpha', 'Delta', 'F0', 'F1']:
if key not in params:
params[key] = self.theta_prior[key]
cmd = ('lalapps_ComputeFstatistic_v2 -a {} -d {} -f {} -s {} -D "{}"'
' --refTime={} --outputLoudest="{}/{}.loudest" '
'--minStartTime={} --maxStartTime={}').format(
params['Alpha'], params['Delta'], params['F0'],
params['F1'], self.sftfilepath, params['tref'],
self.outdir, self.label, self.minStartTime,
self.maxStartTime)
subprocess.call([cmd], shell=True)
def _get_list_of_matching_sfts(self):
matches = [glob.glob(p) for p in self.sftfilepath]
matches = [item for sublist in matches for item in sublist]
......@@ -277,7 +263,10 @@ class ComputeFstat(object):
logging.info('Initialising SFTCatalog')
constraints = lalpulsar.SFTConstraints()
if self.detectors:
constraints.detector = self.detectors
if ',' in self.detectors:
logging.info('Using all detector data')
else:
constraints.detector = self.detectors
if self.minStartTime:
constraints.minStartTime = lal.LIGOTimeGPS(self.minStartTime)
if self.maxStartTime:
......
......@@ -7,6 +7,7 @@ import sys
import argparse
import logging
import inspect
import peakutils
from functools import wraps
import matplotlib.pyplot as plt
......@@ -127,3 +128,26 @@ def initializer(func):
return wrapper
def get_peak_values(frequencies, twoF, threshold_2F, F0=None, F0range=None):
if F0:
cut_idxs = np.abs(frequencies - F0) < F0range
frequencies = frequencies[cut_idxs]
twoF = twoF[cut_idxs]
idxs = peakutils.indexes(twoF, thres=1.*threshold_2F/np.max(twoF))
F0maxs = frequencies[idxs]
twoFmaxs = twoF[idxs]
freq_err = frequencies[1] - frequencies[0]
return F0maxs, twoFmaxs, freq_err*np.ones(len(idxs))
def get_comb_values(F0, frequencies, twoF, period, N=4):
if period == 'sidereal':
period = 23*60*60 + 56*60 + 4.0616
elif period == 'terrestrial':
period = 86400
freq_err = frequencies[1] - frequencies[0]
comb_frequencies = [n*1/period for n in range(-N, N+1)]
comb_idxs = [np.argmin(np.abs(frequencies-F0-F)) for F in comb_frequencies]
return comb_frequencies, twoF[comb_idxs], freq_err*np.ones(len(comb_idxs))
......@@ -5,6 +5,7 @@ import os
import copy
import logging
from collections import OrderedDict
import subprocess
import numpy as np
import matplotlib
......@@ -14,7 +15,7 @@ import corner
import dill as pickle
import core
from core import tqdm, args, earth_ephem, sun_ephem
from core import tqdm, args, earth_ephem, sun_ephem, read_par
from optimal_setup_functions import get_V_estimate
from optimal_setup_functions import get_optimal_setup
import helper_functions
......@@ -38,7 +39,7 @@ class MCMCSearch(core.BaseSearchClass):
maxStartTime, sftfilepath=None, nsteps=[100, 100],
nwalkers=100, ntemps=1, log10temperature_min=-5,
theta_initial=None, scatter_val=1e-10, rhohatmax=1000,
binary=False, BSGL=False, minCoverFreq=None,
binary=False, BSGL=False, minCoverFreq=None, SSBprec=None,
maxCoverFreq=None, detectors=None, earth_ephem=None,
sun_ephem=None, injectSources=None, assumeSqrtSX=None):
"""
......@@ -133,7 +134,7 @@ class MCMCSearch(core.BaseSearchClass):
detectors=self.detectors, BSGL=self.BSGL, transient=False,
minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
binary=self.binary, injectSources=self.injectSources,
assumeSqrtSX=self.assumeSqrtSX)
assumeSqrtSX=self.assumeSqrtSX, SSBprec=self.SSBprec)
def logp(self, theta_vals, theta_prior, theta_keys, search):
H = [self._generic_lnprior(**theta_prior[key])(p) for p, key in
......@@ -292,7 +293,7 @@ class MCMCSearch(core.BaseSearchClass):
mean = np.mean(per_walker_mean, axis=0)
B = N / (M-1.) * np.sum((per_walker_mean-mean)**2, axis=0)
Vhat = (N-1)/N * W + (M+1)/(M*N) * B
c = Vhat/W
c = np.sqrt(Vhat/W)
self.convergence_diagnostic.append(c)
self.convergence_diagnosticx.append(i - self.convergence_length/2)
return c
......@@ -708,7 +709,7 @@ class MCMCSearch(core.BaseSearchClass):
else:
raise ValueError('Not implemented for prior type {}'.format(
prior_dict['type']))
priorln = ax.plot(x, prior, 'r', label='prior')
priorln = ax.plot(x, prior, 'C3', label='prior')
ax.set_xlabel(self.theta_symbols[i])
s = self.samples[:, i]
......@@ -845,7 +846,7 @@ class MCMCSearch(core.BaseSearchClass):
else:
raise ValueError("dist_type {} unknown".format(dist_type))
def _plot_walkers(self, sampler, symbols=None, alpha=0.4, color="k",
def _plot_walkers(self, sampler, symbols=None, alpha=0.8, color="k",
temp=0, lw=0.1, nprod=0, add_det_stat_burnin=False,
fig=None, axes=None, xoffset=0, plot_det_stat=False,
context='ggplot', subtractions=None, labelpad=0.05):
......@@ -903,7 +904,7 @@ class MCMCSearch(core.BaseSearchClass):
if burnin_idx > 0:
axes[i].plot(xoffset+idxs[:convergence_idx+1],
cs[:convergence_idx+1]-subtractions[i],
color="r", alpha=alpha,
color="C3", alpha=alpha,
lw=lw)
axes[i].axvline(xoffset+convergence_idx,
color='k', ls='--', lw=0.25)
......@@ -920,11 +921,15 @@ class MCMCSearch(core.BaseSearchClass):
if hasattr(self, 'convergence_diagnostic'):
ax = axes[i].twinx()
axes[i].set_zorder(ax.get_zorder()+1)
axes[i].patch.set_visible(False)
c_x = np.array(self.convergence_diagnosticx)
c_y = np.array(self.convergence_diagnostic)
break_idx = np.argmin(np.abs(c_x - burnin_idx))
ax.plot(c_x[:break_idx], c_y[:break_idx, i], '-b')
ax.plot(c_x[break_idx:], c_y[break_idx:, i], '-b')
ax.plot(c_x[:break_idx], c_y[:break_idx, i], '-C0',
zorder=-10)
ax.plot(c_x[break_idx:], c_y[break_idx:, i], '-C0',
zorder=-10)
ax.set_ylabel('PSRF')
ax.ticklabel_format(useOffset=False)
ax.set_ylim(0.5, self.convergence_plot_upper_lim)
......@@ -933,7 +938,7 @@ class MCMCSearch(core.BaseSearchClass):
cs = chain[:, :, temp].T
if burnin_idx:
axes[0].plot(idxs[:burnin_idx], cs[:burnin_idx],
color="r", alpha=alpha, lw=lw)
color="C3", alpha=alpha, lw=lw)
axes[0].plot(idxs[burnin_idx:], cs[burnin_idx:], color="k",
alpha=alpha, lw=lw)
if symbols:
......@@ -950,7 +955,7 @@ class MCMCSearch(core.BaseSearchClass):
burn_in_vals = lnl[:, :burnin_idx].flatten()
try:
axes[-1].hist(burn_in_vals[~np.isnan(burn_in_vals)],
bins=50, histtype='step', color='r')
bins=50, histtype='step', color='C3')
except ValueError:
logging.info('Det. Stat. hist failed, most likely all '
'values where the same')
......@@ -1256,6 +1261,21 @@ class MCMCSearch(core.BaseSearchClass):
for key, val in max_twoF_d.iteritems():
f.write('{} = {:1.16e}\n'.format(key, val))
def generate_loudest(self):
self.write_par()
params = read_par(self.label, self.outdir)
for key in ['Alpha', 'Delta', 'F0', 'F1']:
if key not in params:
params[key] = self.theta_prior[key]
cmd = ('lalapps_ComputeFstatistic_v2 -a {} -d {} -f {} -s {} -D "{}"'
' --refTime={} --outputLoudest="{}/{}.loudest" '
'--minStartTime={} --maxStartTime={}').format(
params['Alpha'], params['Delta'], params['F0'],
params['F1'], self.sftfilepath, params['tref'],
self.outdir, self.label, self.minStartTime,
self.maxStartTime)
subprocess.call([cmd], shell=True)
def write_prior_table(self):
with open('{}/{}_prior.tex'.format(self.outdir, self.label), 'w') as f:
f.write(r"\begin{tabular}{c l c} \hline" + '\n'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment