Skip to content
Snippets Groups Projects
Commit a09a1bbe authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Improvements to the way labels are handled

Adds proper unit support by default, and options ot customise through
the dictionaries
parent 98d4ebb4
No related branches found
No related tags found
No related merge requests found
......@@ -22,6 +22,15 @@ import helper_functions
class MCMCSearch(core.BaseSearchClass):
""" MCMC search using ComputeFstat"""
symbol_dictionary = dict(
F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', alpha=r'$\alpha$',
delta='$\delta$')
unit_dictionary = dict(
F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad')
rescale_dictionary = {}
@helper_functions.initializer
def __init__(self, label, outdir, theta_prior, tref, minStartTime,
maxStartTime, sftfilepath=None, nsteps=[100, 100],
......@@ -95,13 +104,6 @@ class MCMCSearch(core.BaseSearchClass):
if args.clean and os.path.isfile(self.pickle_path):
os.rename(self.pickle_path, self.pickle_path+".old")
self.symbol_dictionary = dict(
F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', alpha=r'$\alpha$',
delta='$\delta$')
self.unit_dictionary = dict(
F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad')
self.rescale_dictionary = {}
self.log_input()
def log_input(self):
......@@ -416,21 +418,90 @@ class MCMCSearch(core.BaseSearchClass):
self.lnlikes = lnlikes
self.save_data(sampler, samples, lnprobs, lnlikes)
def scale_samples(self, samples, symbols, theta_keys):
def get_rescale_multiplier_for_key(self, key):
""" Get the rescale multiplier from the rescale_dictionary
Can either be a float, a string (in which case it is interpretted as
a attribute of the MCMCSearch class, e.g. minStartTime, or non-existent
in which case 0 is returned
"""
if key not in self.rescale_dictionary:
return 1
if 'multiplier' in self.rescale_dictionary[key]:
val = self.rescale_dictionary[key]['multiplier']
if type(val) == str:
if hasattr(self, val):
multiplier = getattr(
self, self.rescale_dictionary[key]['multiplier'])
else:
raise ValueError(
"multiplier {} not a class attribute".format(val))
else:
multiplier = val
else:
multiplier = 1
return multiplier
def get_rescale_subtractor_for_key(self, key):
""" Get the rescale subtractor from the rescale_dictionary
Can either be a float, a string (in which case it is interpretted as
a attribute of the MCMCSearch class, e.g. minStartTime, or non-existent
in which case 0 is returned
"""
if key not in self.rescale_dictionary:
return 0
if 'subtractor' in self.rescale_dictionary[key]:
val = self.rescale_dictionary[key]['subtractor']
if type(val) == str:
if hasattr(self, val):
subtractor = getattr(
self, self.rescale_dictionary[key]['subtractor'])
else:
raise ValueError(
"subtractor {} not a class attribute".format(val))
else:
subtractor = val
else:
subtractor = 0
return subtractor
def scale_samples(self, samples, theta_keys):
""" Scale the samples using the rescale_dictionary """
for key in theta_keys:
if key in self.rescale_dictionary:
idx = theta_keys.index(key)
s = samples[:, idx]
if 'subtractor' in self.scale_dictionary[key]:
s = self.scale_dictionary[key]['subtractor'] - s
if 'multipler' in self.scale_dictionary[key]:
s *= self.scale_dictionary[key]['multipler']
subtractor = self.get_rescale_subtractor_for_key(key)
s = s - subtractor
multiplier = self.get_rescale_multiplier_for_key(key)
s *= multiplier
samples[:, idx] = s
if 'label' in self.scale_dictionary['key']:
symbols[idx] = self.scale_dictionary[key]['label']
return samples
return samples, symbols
def get_labels(self):
""" Combine the units, symbols and rescaling to give labels """
labels = []
for key in self.theta_keys:
label = None
s = self.symbol_dictionary[key]
s.replace('_{glitch}', r'_\textrm{glitch}')
u = self.unit_dictionary[key]
if key in self.rescale_dictionary:
if 'symbol' in self.rescale_dictionary[key]:
s = self.rescale_dictionary[key]['symbol']
if 'label' in self.rescale_dictionary[key]:
label = self.rescale_dictionary[key]['label']
if 'unit' in self.rescale_dictionary[key]:
u = self.rescale_dictionary[key]['unit']
if label is None:
label = '{} \n [{}]'.format(s, u)
labels.append(label)
return labels
def plot_corner(self, figsize=(7, 7), tglitch_ratio=False,
add_prior=False, nstds=None, label_offset=0.4,
......@@ -451,12 +522,9 @@ class MCMCSearch(core.BaseSearchClass):
figsize=figsize)
samples_plt = copy.copy(self.samples)
theta_symbols_plt = copy.copy(self.theta_symbols)
labels = self.get_labels()
samples_plt, theta_symbols_plt = self.scale_samples(
samples_plt, theta_symbols_plt, self.theta_keys)
theta_symbols_plt = [s.replace('_{glitch}', r'_\textrm{glitch}')
for s in theta_symbols_plt]
samples_plt = self.scale_samples(samples_plt, self.theta_keys)
if tglitch_ratio:
for j, k in enumerate(self.theta_keys):
......@@ -465,7 +533,7 @@ class MCMCSearch(core.BaseSearchClass):
samples_plt[:, j] = (
s - self.minStartTime)/(
self.maxStartTime - self.minStartTime)
theta_symbols_plt[j] = r'$R_{\textrm{glitch}}$'
labels[j] = r'$R_{\textrm{glitch}}$'
if type(nstds) is int and 'range' not in kwargs:
_range = []
......@@ -477,7 +545,7 @@ class MCMCSearch(core.BaseSearchClass):
_range = None
fig_triangle = corner.corner(samples_plt,
labels=theta_symbols_plt,
labels=labels,
fig=fig,
bins=50,
max_n_ticks=4,
......@@ -515,9 +583,11 @@ class MCMCSearch(core.BaseSearchClass):
s = samples[:, i]
prior = self.generic_lnprior(**self.theta_prior[key])
x = np.linspace(s.min(), s.max(), 100)
multiplier = self.get_rescale_multiplier_for_key(key)
subtractor = self.get_rescale_subtractor_for_key(key)
ax2 = ax.twinx()
ax2.get_yaxis().set_visible(False)
ax2.plot(x, [prior(xi) for xi in x], '-r')
ax2.plot(x, [(prior(xi)-subtractor)*multiplier for xi in x], '-r')
ax.set_xlim(xlim)
def plot_prior_posterior(self, normal_stds=2):
......@@ -1202,6 +1272,16 @@ class MCMCSearch(core.BaseSearchClass):
class MCMCGlitchSearch(MCMCSearch):
""" MCMC search using the SemiCoherentGlitchSearch """
symbol_dictionary = dict(
F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', alpha=r'$\alpha$',
delta='$\delta$', delta_F0='$\delta f$',
delta_F1='$\delta \dot{f}$', tglitch='$t_\mathrm{glitch}$')
unit_dictionary = dict(
F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad',
delta_F0='Hz', delta_F1='Hz/s', tglitch='s')
rescale_dictionary = dict()
@helper_functions.initializer
def __init__(self, label, outdir, sftfilepath, theta_prior, tref,
minStartTime, maxStartTime, nglitch=1, nsteps=[100, 100],
......@@ -1286,14 +1366,6 @@ class MCMCGlitchSearch(MCMCSearch):
os.rename(self.pickle_path, self.pickle_path+".old")
self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
self.symbol_dictionary = dict(
F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', alpha=r'$\alpha$',
delta='$\delta$', delta_F0='$\delta f$',
delta_F1='$\delta \dot{f}$', tglitch='$t_\mathrm{glitch}$')
self.unit_dictionary = dict(
F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad',
delta_F0='Hz', delta_F1='Hz/s', tglitch='s')
self.rescale_dictionary = dict()
self.log_input()
def initiate_search_object(self):
......@@ -1844,17 +1916,20 @@ class MCMCTransientSearch(MCMCSearch):
symbol_dictionary = dict(
F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$',
alpha=r'$\alpha$', delta='$\delta$', tstart='$t_\mathrm{start}$',
tend='$t_\mathrm{end}$')
alpha=r'$\alpha$', delta='$\delta$',
transient_tstart='$t_\mathrm{start}$', transient_duration='$\Delta T$')
unit_dictionary = dict(
F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad',
tstart='s', tend='s')
transient_tstart='s', transient_duration='s')
rescale_dictionary = dict(
transient_duration={'multiplier': 1/86400.,
'label': 'Transient duration'},
'unit': 'day',
'symbol': 'Transient duration'},
transient_tstart={
'multiplier': 1/86400.,
'subtractor': 'minStartTime',
'unit': 'day',
'label': 'Transient start-time \n days after minStartTime'}
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment