Commit ea5474da authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Various small fixes to MCMC scripts

- Fixes undefined theta_values -> theta (missed as Ng > 1 isn't tested)
- Import from core rathern than explicitly
- Various PEP8 fixes
parent c888a265
......@@ -13,14 +13,14 @@ import emcee
import corner
import dill as pickle
from core import BaseSearchClass, ComputeFstat, SemiCoherentSearch
from optimal_setup_functions import get_V_estimate
import core
from core import tqdm, args, earth_ephem, sun_ephem
from optimal_setup_functions import get_V_estimate
from optimal_setup_functions import get_optimal_setup
import helper_functions
class MCMCSearch(BaseSearchClass):
class MCMCSearch(core.BaseSearchClass):
""" MCMC search using ComputeFstat"""
@helper_functions.initializer
def __init__(self, label, outdir, theta_prior, tref, minStartTime,
......@@ -108,7 +108,7 @@ class MCMCSearch(BaseSearchClass):
def initiate_search_object(self):
logging.info('Setting up search object')
self.search = ComputeFstat(
self.search = core.ComputeFstat(
tref=self.tref, sftfilepath=self.sftfilepath,
minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq,
earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem,
......@@ -1084,7 +1084,8 @@ class MCMCSearch(BaseSearchClass):
if self.nglitch == 1:
tglitches = [d['tglitch']]
else:
tglitches = [d['tglitch_{}'.format(i)] for i in range(self.nglitch)]
tglitches = [d['tglitch_{}'.format(i)]
for i in range(self.nglitch)]
tboundaries = [self.minStartTime] + tglitches + [self.maxStartTime]
deltaTs = np.diff(tboundaries)
ntrials = [time_trials + delta_F0 * dT for dT in deltaTs]
......@@ -1234,7 +1235,7 @@ class MCMCGlitchSearch(MCMCSearch):
def initiate_search_object(self):
logging.info('Setting up search object')
self.search = SemiCoherentGlitchSearch(
self.search = core.SemiCoherentGlitchSearch(
label=self.label, outdir=self.outdir, sftfilepath=self.sftfilepath,
tref=self.tref, minStartTime=self.minStartTime,
maxStartTime=self.maxStartTime, minCoverFreq=self.minCoverFreq,
......@@ -1257,7 +1258,7 @@ class MCMCGlitchSearch(MCMCSearch):
def logl(self, theta, search):
if self.nglitch > 1:
ts = ([self.minStartTime] + list(theta_vals[-self.nglitch:])
ts = ([self.minStartTime] + list(theta[-self.nglitch:])
+ [self.maxStartTime])
if np.array_equal(ts, np.sort(ts)) is False:
return -np.inf
......@@ -1399,12 +1400,12 @@ class MCMCSemiCoherentSearch(MCMCSearch):
""" MCMC search for a signal using the semi-coherent ComputeFstat """
@helper_functions.initializer
def __init__(self, label, outdir, theta_prior, tref, sftfilepath=None,
nsegs=None, nsteps=[100, 100, 100], nwalkers=100, binary=False,
ntemps=1, log10temperature_min=-5, theta_initial=None,
scatter_val=1e-10, detectors=None, BSGL=False,
minStartTime=None, maxStartTime=None, minCoverFreq=None,
maxCoverFreq=None, earth_ephem=None, sun_ephem=None,
injectSources=None, assumeSqrtSX=None):
nsegs=None, nsteps=[100, 100, 100], nwalkers=100,
binary=False, ntemps=1, log10temperature_min=-5,
theta_initial=None, scatter_val=1e-10, detectors=None,
BSGL=False, minStartTime=None, maxStartTime=None,
minCoverFreq=None, maxCoverFreq=None, earth_ephem=None,
sun_ephem=None, injectSources=None, assumeSqrtSX=None):
"""
"""
......@@ -1442,7 +1443,7 @@ class MCMCSemiCoherentSearch(MCMCSearch):
def initiate_search_object(self):
logging.info('Setting up search object')
self.search = SemiCoherentSearch(
self.search = core.SemiCoherentSearch(
label=self.label, outdir=self.outdir, tref=self.tref,
nsegs=self.nsegs, sftfilepath=self.sftfilepath, binary=self.binary,
BSGL=self.BSGL, minStartTime=self.minStartTime,
......@@ -1564,7 +1565,8 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
V_vals = old_setup['V_vals']
generate_setup = False
else:
logging.info('Old setup does not match requested R, Nsegs0')
logging.info(
'Old setup does not match requested R, Nsegs0')
generate_setup = True
else:
generate_setup = True
......@@ -1779,7 +1781,7 @@ class MCMCTransientSearch(MCMCSearch):
def initiate_search_object(self):
logging.info('Setting up search object')
self.search = ComputeFstat(
self.search = core.ComputeFstat(
tref=self.tref, sftfilepath=self.sftfilepath,
minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq,
earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem,
......@@ -1840,6 +1842,3 @@ class MCMCTransientSearch(MCMCSearch):
self.theta_idxs = [self.theta_idxs[i] for i in idxs]
self.theta_symbols = [self.theta_symbols[i] for i in idxs]
self.theta_keys = [self.theta_keys[i] for i in idxs]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment