Commit a09a1bbe authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Improvements to the way labels are handled

Adds proper unit support by default, and options ot customise through
the dictionaries
parent 98d4ebb4
......@@ -22,6 +22,15 @@ import helper_functions
class MCMCSearch(core.BaseSearchClass):
""" MCMC search using ComputeFstat"""
symbol_dictionary = dict(
F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', alpha=r'$\alpha$',
delta='$\delta$')
unit_dictionary = dict(
F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad')
rescale_dictionary = {}
@helper_functions.initializer
def __init__(self, label, outdir, theta_prior, tref, minStartTime,
maxStartTime, sftfilepath=None, nsteps=[100, 100],
......@@ -95,13 +104,6 @@ class MCMCSearch(core.BaseSearchClass):
if args.clean and os.path.isfile(self.pickle_path):
os.rename(self.pickle_path, self.pickle_path+".old")
self.symbol_dictionary = dict(
F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', alpha=r'$\alpha$',
delta='$\delta$')
self.unit_dictionary = dict(
F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad')
self.rescale_dictionary = {}
self.log_input()
def log_input(self):
......@@ -416,21 +418,90 @@ class MCMCSearch(core.BaseSearchClass):
self.lnlikes = lnlikes
self.save_data(sampler, samples, lnprobs, lnlikes)
def scale_samples(self, samples, symbols, theta_keys):
def get_rescale_multiplier_for_key(self, key):
""" Get the rescale multiplier from the rescale_dictionary
Can either be a float, a string (in which case it is interpretted as
a attribute of the MCMCSearch class, e.g. minStartTime, or non-existent
in which case 0 is returned
"""
if key not in self.rescale_dictionary:
return 1
if 'multiplier' in self.rescale_dictionary[key]:
val = self.rescale_dictionary[key]['multiplier']
if type(val) == str:
if hasattr(self, val):
multiplier = getattr(
self, self.rescale_dictionary[key]['multiplier'])
else:
raise ValueError(
"multiplier {} not a class attribute".format(val))
else:
multiplier = val
else:
multiplier = 1
return multiplier
def get_rescale_subtractor_for_key(self, key):
""" Get the rescale subtractor from the rescale_dictionary
Can either be a float, a string (in which case it is interpretted as
a attribute of the MCMCSearch class, e.g. minStartTime, or non-existent
in which case 0 is returned
"""
if key not in self.rescale_dictionary:
return 0
if 'subtractor' in self.rescale_dictionary[key]:
val = self.rescale_dictionary[key]['subtractor']
if type(val) == str:
if hasattr(self, val):
subtractor = getattr(
self, self.rescale_dictionary[key]['subtractor'])
else:
raise ValueError(
"subtractor {} not a class attribute".format(val))
else:
subtractor = val
else:
subtractor = 0
return subtractor
def scale_samples(self, samples, theta_keys):
""" Scale the samples using the rescale_dictionary """
for key in theta_keys:
if key in self.rescale_dictionary:
idx = theta_keys.index(key)
s = samples[:, idx]
if 'subtractor' in self.scale_dictionary[key]:
s = self.scale_dictionary[key]['subtractor'] - s
if 'multipler' in self.scale_dictionary[key]:
s *= self.scale_dictionary[key]['multipler']
subtractor = self.get_rescale_subtractor_for_key(key)
s = s - subtractor
multiplier = self.get_rescale_multiplier_for_key(key)
s *= multiplier
samples[:, idx] = s
if 'label' in self.scale_dictionary['key']:
symbols[idx] = self.scale_dictionary[key]['label']
return samples
def get_labels(self):
""" Combine the units, symbols and rescaling to give labels """
return samples, symbols
labels = []
for key in self.theta_keys:
label = None
s = self.symbol_dictionary[key]
s.replace('_{glitch}', r'_\textrm{glitch}')
u = self.unit_dictionary[key]
if key in self.rescale_dictionary:
if 'symbol' in self.rescale_dictionary[key]:
s = self.rescale_dictionary[key]['symbol']
if 'label' in self.rescale_dictionary[key]:
label = self.rescale_dictionary[key]['label']
if 'unit' in self.rescale_dictionary[key]:
u = self.rescale_dictionary[key]['unit']
if label is None:
label = '{} \n [{}]'.format(s, u)
labels.append(label)
return labels
def plot_corner(self, figsize=(7, 7), tglitch_ratio=False,
add_prior=False, nstds=None, label_offset=0.4,
......@@ -451,12 +522,9 @@ class MCMCSearch(core.BaseSearchClass):
figsize=figsize)
samples_plt = copy.copy(self.samples)
theta_symbols_plt = copy.copy(self.theta_symbols)
labels = self.get_labels()
samples_plt, theta_symbols_plt = self.scale_samples(
samples_plt, theta_symbols_plt, self.theta_keys)
theta_symbols_plt = [s.replace('_{glitch}', r'_\textrm{glitch}')
for s in theta_symbols_plt]
samples_plt = self.scale_samples(samples_plt, self.theta_keys)
if tglitch_ratio:
for j, k in enumerate(self.theta_keys):
......@@ -465,7 +533,7 @@ class MCMCSearch(core.BaseSearchClass):
samples_plt[:, j] = (
s - self.minStartTime)/(
self.maxStartTime - self.minStartTime)
theta_symbols_plt[j] = r'$R_{\textrm{glitch}}$'
labels[j] = r'$R_{\textrm{glitch}}$'
if type(nstds) is int and 'range' not in kwargs:
_range = []
......@@ -477,7 +545,7 @@ class MCMCSearch(core.BaseSearchClass):
_range = None
fig_triangle = corner.corner(samples_plt,
labels=theta_symbols_plt,
labels=labels,
fig=fig,
bins=50,
max_n_ticks=4,
......@@ -515,9 +583,11 @@ class MCMCSearch(core.BaseSearchClass):
s = samples[:, i]
prior = self.generic_lnprior(**self.theta_prior[key])
x = np.linspace(s.min(), s.max(), 100)
multiplier = self.get_rescale_multiplier_for_key(key)
subtractor = self.get_rescale_subtractor_for_key(key)
ax2 = ax.twinx()
ax2.get_yaxis().set_visible(False)
ax2.plot(x, [prior(xi) for xi in x], '-r')
ax2.plot(x, [(prior(xi)-subtractor)*multiplier for xi in x], '-r')
ax.set_xlim(xlim)
def plot_prior_posterior(self, normal_stds=2):
......@@ -1202,6 +1272,16 @@ class MCMCSearch(core.BaseSearchClass):
class MCMCGlitchSearch(MCMCSearch):
""" MCMC search using the SemiCoherentGlitchSearch """
symbol_dictionary = dict(
F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', alpha=r'$\alpha$',
delta='$\delta$', delta_F0='$\delta f$',
delta_F1='$\delta \dot{f}$', tglitch='$t_\mathrm{glitch}$')
unit_dictionary = dict(
F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad',
delta_F0='Hz', delta_F1='Hz/s', tglitch='s')
rescale_dictionary = dict()
@helper_functions.initializer
def __init__(self, label, outdir, sftfilepath, theta_prior, tref,
minStartTime, maxStartTime, nglitch=1, nsteps=[100, 100],
......@@ -1286,14 +1366,6 @@ class MCMCGlitchSearch(MCMCSearch):
os.rename(self.pickle_path, self.pickle_path+".old")
self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
self.symbol_dictionary = dict(
F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', alpha=r'$\alpha$',
delta='$\delta$', delta_F0='$\delta f$',
delta_F1='$\delta \dot{f}$', tglitch='$t_\mathrm{glitch}$')
self.unit_dictionary = dict(
F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad',
delta_F0='Hz', delta_F1='Hz/s', tglitch='s')
self.rescale_dictionary = dict()
self.log_input()
def initiate_search_object(self):
......@@ -1844,17 +1916,20 @@ class MCMCTransientSearch(MCMCSearch):
symbol_dictionary = dict(
F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$',
alpha=r'$\alpha$', delta='$\delta$', tstart='$t_\mathrm{start}$',
tend='$t_\mathrm{end}$')
alpha=r'$\alpha$', delta='$\delta$',
transient_tstart='$t_\mathrm{start}$', transient_duration='$\Delta T$')
unit_dictionary = dict(
F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad',
tstart='s', tend='s')
transient_tstart='s', transient_duration='s')
rescale_dictionary = dict(
transient_duration={'multiplier': 1/86400.,
'label': 'Transient duration'},
'unit': 'day',
'symbol': 'Transient duration'},
transient_tstart={
'multiplier': 1/86400.,
'subtractor': 'minStartTime',
'unit': 'day',
'label': 'Transient start-time \n days after minStartTime'}
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment