Commit dc4b3b63 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Adds ability to defin theta0_idx

Also various improvements to logging commands
parent 9543b886
...@@ -132,15 +132,23 @@ class BaseSearchClass(object): ...@@ -132,15 +132,23 @@ class BaseSearchClass(object):
m = self.shift_matrix(n, dT) m = self.shift_matrix(n, dT)
return np.dot(m, theta) return np.dot(m, theta)
def calculate_thetas(self, theta, delta_thetas, tbounds): def calculate_thetas(self, theta, delta_thetas, tbounds, theta0_idx=0):
""" Calculates the set of coefficients for the post-glitch signal """ """ Calculates the set of coefficients for the post-glitch signal """
thetas = [theta] thetas = [theta]
for i, dt in enumerate(delta_thetas): for i, dt in enumerate(delta_thetas):
pre_theta_at_ith_glitch = self.shift_coefficients( if i < theta0_idx:
thetas[i], tbounds[i+1] - self.tref) pre_theta_at_ith_glitch = self.shift_coefficients(
post_theta_at_ith_glitch = pre_theta_at_ith_glitch + dt thetas[0], tbounds[i+1] - self.tref)
thetas.append(self.shift_coefficients( post_theta_at_ith_glitch = pre_theta_at_ith_glitch - dt
post_theta_at_ith_glitch, self.tref - tbounds[i+1])) thetas.insert(0, self.shift_coefficients(
post_theta_at_ith_glitch, self.tref - tbounds[i+1]))
elif i >= theta0_idx:
pre_theta_at_ith_glitch = self.shift_coefficients(
thetas[i], tbounds[i+1] - self.tref)
post_theta_at_ith_glitch = pre_theta_at_ith_glitch + dt
thetas.append(self.shift_coefficients(
post_theta_at_ith_glitch, self.tref - tbounds[i+1]))
return thetas return thetas
...@@ -294,7 +302,7 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat): ...@@ -294,7 +302,7 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat):
@initializer @initializer
def __init__(self, label, outdir, tref, tstart, tend, nglitch=0, def __init__(self, label, outdir, tref, tstart, tend, nglitch=0,
sftlabel=None, sftdir=None, minCoverFreq=None, sftlabel=None, sftdir=None, theta0_idx=0, minCoverFreq=None,
maxCoverFreq=None, detector=None, earth_ephem=None, maxCoverFreq=None, detector=None, earth_ephem=None,
sun_ephem=None): sun_ephem=None):
""" """
...@@ -310,6 +318,10 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat): ...@@ -310,6 +318,10 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat):
sftlabel, sftdir: str sftlabel, sftdir: str
A label and directory in which to find the relevant sft file. If A label and directory in which to find the relevant sft file. If
None use label and outdir. None use label and outdir.
theta0_idx, int
Index (zero-based) of which segment the theta refers to - uyseful
if providing a tight prior on theta to allow the signal to jump
too theta (and not just from)
minCoverFreq, maxCoverFreq: float minCoverFreq, maxCoverFreq: float
The min and max cover frequency passed to CreateFstatInput, if The min and max cover frequency passed to CreateFstatInput, if
either is None the range of frequencies in the SFT less 1Hz is either is None the range of frequencies in the SFT less 1Hz is
...@@ -349,7 +361,8 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat): ...@@ -349,7 +361,8 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat):
delta_thetas = np.atleast_2d( delta_thetas = np.atleast_2d(
np.array([delta_phi, delta_F0s, delta_F1s, delta_F2]).T) np.array([delta_phi, delta_F0s, delta_F1s, delta_F2]).T)
thetas = self.calculate_thetas(theta, delta_thetas, tboundaries) thetas = self.calculate_thetas(theta, delta_thetas, tboundaries,
theta0_idx=self.theta0_idx)
twoFSum = 0 twoFSum = 0
for i, theta_i_at_tref in enumerate(thetas): for i, theta_i_at_tref in enumerate(thetas):
...@@ -817,10 +830,24 @@ class MCMCSearch(BaseSearchClass): ...@@ -817,10 +830,24 @@ class MCMCSearch(BaseSearchClass):
ntemps_temp, self.nwalkers, self.ndim)[0, :, :] ntemps_temp, self.nwalkers, self.ndim)[0, :, :]
lnp = sampler.lnprobability[:, :, -1].reshape( lnp = sampler.lnprobability[:, :, -1].reshape(
self.ntemps, self.nwalkers)[0, :] self.ntemps, self.nwalkers)[0, :]
# General warnings about the state of lnp
if any(np.isnan(lnp)): if any(np.isnan(lnp)):
logging.warning("The sampler has produced nan's") logging.warning(
"Of {} lnprobs {} are nan".format(
len(lnp), np.sum(np.isnan(lnp))))
if any(np.isposinf(lnp)):
logging.warning(
"Of {} lnprobs {} are +np.inf".format(
len(lnp), np.sum(np.isposinf(lnp))))
if any(np.isneginf(lnp)):
logging.warning(
"Of {} lnprobs {} are -np.inf".format(
len(lnp), np.sum(np.isneginf(lnp))))
p = pF[np.nanargmax(lnp)] lnp_finite = copy.copy(lnp)
lnp_finite[np.isinf(lnp)] = np.nan
p = pF[np.nanargmax(lnp_finite)]
p0 = self.generate_scattered_p0(p) p0 = self.generate_scattered_p0(p)
return p0 return p0
...@@ -829,7 +856,8 @@ class MCMCSearch(BaseSearchClass): ...@@ -829,7 +856,8 @@ class MCMCSearch(BaseSearchClass):
d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers, d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers,
ntemps=self.ntemps, theta_keys=self.theta_keys, ntemps=self.ntemps, theta_keys=self.theta_keys,
theta_prior=self.theta_prior, scatter_val=self.scatter_val, theta_prior=self.theta_prior, scatter_val=self.scatter_val,
log10temperature_min=self.log10temperature_min) log10temperature_min=self.log10temperature_min,
theta0_idx=self.theta0_idx)
return d return d
def save_data(self, sampler, samples, lnprobs, lnlikes): def save_data(self, sampler, samples, lnprobs, lnlikes):
...@@ -888,7 +916,7 @@ class MCMCSearch(BaseSearchClass): ...@@ -888,7 +916,7 @@ class MCMCSearch(BaseSearchClass):
if new_d[key] != old_d[key]: if new_d[key] != old_d[key]:
mod_keys.append((key, old_d[key], new_d[key])) mod_keys.append((key, old_d[key], new_d[key]))
else: else:
raise ValueError('Keys do not match') raise ValueError('Keys {} not in old dictionary'.format(key))
if len(mod_keys) == 0: if len(mod_keys) == 0:
return True return True
...@@ -954,6 +982,7 @@ class MCMCSearch(BaseSearchClass): ...@@ -954,6 +982,7 @@ class MCMCSearch(BaseSearchClass):
filename = '{}/{}.par'.format(self.outdir, self.label) filename = '{}/{}.par'.format(self.outdir, self.label)
with open(filename, 'w+') as f: with open(filename, 'w+') as f:
f.write('MaxtwoF = {}\n'.format(max_twoF)) f.write('MaxtwoF = {}\n'.format(max_twoF))
f.write('theta0_index = {}\n'.format(self.theta0_idx))
if method == 'med': if method == 'med':
for key, val in median_std_d.iteritems(): for key, val in median_std_d.iteritems():
f.write('{} = {:1.16e}\n'.format(key, val)) f.write('{} = {:1.16e}\n'.format(key, val))
...@@ -964,6 +993,7 @@ class MCMCSearch(BaseSearchClass): ...@@ -964,6 +993,7 @@ class MCMCSearch(BaseSearchClass):
def print_summary(self): def print_summary(self):
d, max_twoF = self.get_max_twoF() d, max_twoF = self.get_max_twoF()
print('Max twoF: {}'.format(max_twoF)) print('Max twoF: {}'.format(max_twoF))
print('theta0 index: {}'.format(self.theta0_idx))
for k in np.sort(d.keys()): for k in np.sort(d.keys()):
if 'std' not in k: if 'std' not in k:
print('{:10s} = {:1.9e} +/- {:1.9e}'.format( print('{:10s} = {:1.9e} +/- {:1.9e}'.format(
...@@ -976,7 +1006,8 @@ class MCMCGlitchSearch(MCMCSearch): ...@@ -976,7 +1006,8 @@ class MCMCGlitchSearch(MCMCSearch):
def __init__(self, label, outdir, sftlabel, sftdir, theta_prior, tref, def __init__(self, label, outdir, sftlabel, sftdir, theta_prior, tref,
tstart, tend, nglitch=1, nsteps=[100, 100, 100], nwalkers=100, tstart, tend, nglitch=1, nsteps=[100, 100, 100], nwalkers=100,
ntemps=1, log10temperature_min=-5, theta_initial=None, ntemps=1, log10temperature_min=-5, theta_initial=None,
scatter_val=1e-4, dtglitchmin=1*86400, detector=None, scatter_val=1e-4, dtglitchmin=1*86400, theta0_idx=0,
detector=None,
minCoverFreq=None, maxCoverFreq=None, earth_ephem=None, minCoverFreq=None, maxCoverFreq=None, earth_ephem=None,
sun_ephem=None): sun_ephem=None):
""" """
...@@ -1016,6 +1047,10 @@ class MCMCGlitchSearch(MCMCSearch): ...@@ -1016,6 +1047,10 @@ class MCMCGlitchSearch(MCMCSearch):
log10temperature_min float < 0 log10temperature_min float < 0
The log_10(tmin) value, the set of betas passed to PTSampler are The log_10(tmin) value, the set of betas passed to PTSampler are
generated from np.logspace(0, log10temperature_min, ntemps). generated from np.logspace(0, log10temperature_min, ntemps).
theta0_idx, int
Index (zero-based) of which segment the theta refers to - uyseful
if providing a tight prior on theta to allow the signal to jump
too theta (and not just from)
detector: str detector: str
Two character reference to the data to use, specify None for no Two character reference to the data to use, specify None for no
contraint. contraint.
...@@ -1057,7 +1092,7 @@ class MCMCGlitchSearch(MCMCSearch): ...@@ -1057,7 +1092,7 @@ class MCMCGlitchSearch(MCMCSearch):
tend=self.tend, minCoverFreq=self.minCoverFreq, tend=self.tend, minCoverFreq=self.minCoverFreq,
maxCoverFreq=self.maxCoverFreq, earth_ephem=self.earth_ephem, maxCoverFreq=self.maxCoverFreq, earth_ephem=self.earth_ephem,
sun_ephem=self.sun_ephem, detector=self.detector, sun_ephem=self.sun_ephem, detector=self.detector,
nglitch=self.nglitch) nglitch=self.nglitch, theta0_idx=self.theta0_idx)
def logp(self, theta_vals, theta_prior, theta_keys, search): def logp(self, theta_vals, theta_prior, theta_keys, search):
if self.nglitch > 1: if self.nglitch > 1:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment