Skip to content
Snippets Groups Projects
Commit a09a1bbe authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Improvements to the way labels are handled

Adds proper unit support by default, and options ot customise through
the dictionaries
parent 98d4ebb4
No related branches found
No related tags found
No related merge requests found
...@@ -22,6 +22,15 @@ import helper_functions ...@@ -22,6 +22,15 @@ import helper_functions
class MCMCSearch(core.BaseSearchClass): class MCMCSearch(core.BaseSearchClass):
""" MCMC search using ComputeFstat""" """ MCMC search using ComputeFstat"""
symbol_dictionary = dict(
F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', alpha=r'$\alpha$',
delta='$\delta$')
unit_dictionary = dict(
F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad')
rescale_dictionary = {}
@helper_functions.initializer @helper_functions.initializer
def __init__(self, label, outdir, theta_prior, tref, minStartTime, def __init__(self, label, outdir, theta_prior, tref, minStartTime,
maxStartTime, sftfilepath=None, nsteps=[100, 100], maxStartTime, sftfilepath=None, nsteps=[100, 100],
...@@ -95,13 +104,6 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -95,13 +104,6 @@ class MCMCSearch(core.BaseSearchClass):
if args.clean and os.path.isfile(self.pickle_path): if args.clean and os.path.isfile(self.pickle_path):
os.rename(self.pickle_path, self.pickle_path+".old") os.rename(self.pickle_path, self.pickle_path+".old")
self.symbol_dictionary = dict(
F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', alpha=r'$\alpha$',
delta='$\delta$')
self.unit_dictionary = dict(
F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad')
self.rescale_dictionary = {}
self.log_input() self.log_input()
def log_input(self): def log_input(self):
...@@ -416,21 +418,90 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -416,21 +418,90 @@ class MCMCSearch(core.BaseSearchClass):
self.lnlikes = lnlikes self.lnlikes = lnlikes
self.save_data(sampler, samples, lnprobs, lnlikes) self.save_data(sampler, samples, lnprobs, lnlikes)
def scale_samples(self, samples, symbols, theta_keys): def get_rescale_multiplier_for_key(self, key):
""" Get the rescale multiplier from the rescale_dictionary
Can either be a float, a string (in which case it is interpretted as
a attribute of the MCMCSearch class, e.g. minStartTime, or non-existent
in which case 0 is returned
"""
if key not in self.rescale_dictionary:
return 1
if 'multiplier' in self.rescale_dictionary[key]:
val = self.rescale_dictionary[key]['multiplier']
if type(val) == str:
if hasattr(self, val):
multiplier = getattr(
self, self.rescale_dictionary[key]['multiplier'])
else:
raise ValueError(
"multiplier {} not a class attribute".format(val))
else:
multiplier = val
else:
multiplier = 1
return multiplier
def get_rescale_subtractor_for_key(self, key):
""" Get the rescale subtractor from the rescale_dictionary
Can either be a float, a string (in which case it is interpretted as
a attribute of the MCMCSearch class, e.g. minStartTime, or non-existent
in which case 0 is returned
"""
if key not in self.rescale_dictionary:
return 0
if 'subtractor' in self.rescale_dictionary[key]:
val = self.rescale_dictionary[key]['subtractor']
if type(val) == str:
if hasattr(self, val):
subtractor = getattr(
self, self.rescale_dictionary[key]['subtractor'])
else:
raise ValueError(
"subtractor {} not a class attribute".format(val))
else:
subtractor = val
else:
subtractor = 0
return subtractor
def scale_samples(self, samples, theta_keys):
""" Scale the samples using the rescale_dictionary """
for key in theta_keys: for key in theta_keys:
if key in self.rescale_dictionary: if key in self.rescale_dictionary:
idx = theta_keys.index(key) idx = theta_keys.index(key)
s = samples[:, idx] s = samples[:, idx]
if 'subtractor' in self.scale_dictionary[key]: subtractor = self.get_rescale_subtractor_for_key(key)
s = self.scale_dictionary[key]['subtractor'] - s s = s - subtractor
if 'multipler' in self.scale_dictionary[key]: multiplier = self.get_rescale_multiplier_for_key(key)
s *= self.scale_dictionary[key]['multipler'] s *= multiplier
samples[:, idx] = s samples[:, idx] = s
if 'label' in self.scale_dictionary['key']: return samples
symbols[idx] = self.scale_dictionary[key]['label']
return samples, symbols def get_labels(self):
""" Combine the units, symbols and rescaling to give labels """
labels = []
for key in self.theta_keys:
label = None
s = self.symbol_dictionary[key]
s.replace('_{glitch}', r'_\textrm{glitch}')
u = self.unit_dictionary[key]
if key in self.rescale_dictionary:
if 'symbol' in self.rescale_dictionary[key]:
s = self.rescale_dictionary[key]['symbol']
if 'label' in self.rescale_dictionary[key]:
label = self.rescale_dictionary[key]['label']
if 'unit' in self.rescale_dictionary[key]:
u = self.rescale_dictionary[key]['unit']
if label is None:
label = '{} \n [{}]'.format(s, u)
labels.append(label)
return labels
def plot_corner(self, figsize=(7, 7), tglitch_ratio=False, def plot_corner(self, figsize=(7, 7), tglitch_ratio=False,
add_prior=False, nstds=None, label_offset=0.4, add_prior=False, nstds=None, label_offset=0.4,
...@@ -451,12 +522,9 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -451,12 +522,9 @@ class MCMCSearch(core.BaseSearchClass):
figsize=figsize) figsize=figsize)
samples_plt = copy.copy(self.samples) samples_plt = copy.copy(self.samples)
theta_symbols_plt = copy.copy(self.theta_symbols) labels = self.get_labels()
samples_plt, theta_symbols_plt = self.scale_samples( samples_plt = self.scale_samples(samples_plt, self.theta_keys)
samples_plt, theta_symbols_plt, self.theta_keys)
theta_symbols_plt = [s.replace('_{glitch}', r'_\textrm{glitch}')
for s in theta_symbols_plt]
if tglitch_ratio: if tglitch_ratio:
for j, k in enumerate(self.theta_keys): for j, k in enumerate(self.theta_keys):
...@@ -465,7 +533,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -465,7 +533,7 @@ class MCMCSearch(core.BaseSearchClass):
samples_plt[:, j] = ( samples_plt[:, j] = (
s - self.minStartTime)/( s - self.minStartTime)/(
self.maxStartTime - self.minStartTime) self.maxStartTime - self.minStartTime)
theta_symbols_plt[j] = r'$R_{\textrm{glitch}}$' labels[j] = r'$R_{\textrm{glitch}}$'
if type(nstds) is int and 'range' not in kwargs: if type(nstds) is int and 'range' not in kwargs:
_range = [] _range = []
...@@ -477,7 +545,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -477,7 +545,7 @@ class MCMCSearch(core.BaseSearchClass):
_range = None _range = None
fig_triangle = corner.corner(samples_plt, fig_triangle = corner.corner(samples_plt,
labels=theta_symbols_plt, labels=labels,
fig=fig, fig=fig,
bins=50, bins=50,
max_n_ticks=4, max_n_ticks=4,
...@@ -515,9 +583,11 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -515,9 +583,11 @@ class MCMCSearch(core.BaseSearchClass):
s = samples[:, i] s = samples[:, i]
prior = self.generic_lnprior(**self.theta_prior[key]) prior = self.generic_lnprior(**self.theta_prior[key])
x = np.linspace(s.min(), s.max(), 100) x = np.linspace(s.min(), s.max(), 100)
multiplier = self.get_rescale_multiplier_for_key(key)
subtractor = self.get_rescale_subtractor_for_key(key)
ax2 = ax.twinx() ax2 = ax.twinx()
ax2.get_yaxis().set_visible(False) ax2.get_yaxis().set_visible(False)
ax2.plot(x, [prior(xi) for xi in x], '-r') ax2.plot(x, [(prior(xi)-subtractor)*multiplier for xi in x], '-r')
ax.set_xlim(xlim) ax.set_xlim(xlim)
def plot_prior_posterior(self, normal_stds=2): def plot_prior_posterior(self, normal_stds=2):
...@@ -1202,6 +1272,16 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -1202,6 +1272,16 @@ class MCMCSearch(core.BaseSearchClass):
class MCMCGlitchSearch(MCMCSearch): class MCMCGlitchSearch(MCMCSearch):
""" MCMC search using the SemiCoherentGlitchSearch """ """ MCMC search using the SemiCoherentGlitchSearch """
symbol_dictionary = dict(
F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', alpha=r'$\alpha$',
delta='$\delta$', delta_F0='$\delta f$',
delta_F1='$\delta \dot{f}$', tglitch='$t_\mathrm{glitch}$')
unit_dictionary = dict(
F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad',
delta_F0='Hz', delta_F1='Hz/s', tglitch='s')
rescale_dictionary = dict()
@helper_functions.initializer @helper_functions.initializer
def __init__(self, label, outdir, sftfilepath, theta_prior, tref, def __init__(self, label, outdir, sftfilepath, theta_prior, tref,
minStartTime, maxStartTime, nglitch=1, nsteps=[100, 100], minStartTime, maxStartTime, nglitch=1, nsteps=[100, 100],
...@@ -1286,14 +1366,6 @@ class MCMCGlitchSearch(MCMCSearch): ...@@ -1286,14 +1366,6 @@ class MCMCGlitchSearch(MCMCSearch):
os.rename(self.pickle_path, self.pickle_path+".old") os.rename(self.pickle_path, self.pickle_path+".old")
self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use() self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
self.symbol_dictionary = dict(
F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', alpha=r'$\alpha$',
delta='$\delta$', delta_F0='$\delta f$',
delta_F1='$\delta \dot{f}$', tglitch='$t_\mathrm{glitch}$')
self.unit_dictionary = dict(
F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad',
delta_F0='Hz', delta_F1='Hz/s', tglitch='s')
self.rescale_dictionary = dict()
self.log_input() self.log_input()
def initiate_search_object(self): def initiate_search_object(self):
...@@ -1844,17 +1916,20 @@ class MCMCTransientSearch(MCMCSearch): ...@@ -1844,17 +1916,20 @@ class MCMCTransientSearch(MCMCSearch):
symbol_dictionary = dict( symbol_dictionary = dict(
F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$',
alpha=r'$\alpha$', delta='$\delta$', tstart='$t_\mathrm{start}$', alpha=r'$\alpha$', delta='$\delta$',
tend='$t_\mathrm{end}$') transient_tstart='$t_\mathrm{start}$', transient_duration='$\Delta T$')
unit_dictionary = dict( unit_dictionary = dict(
F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad', F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad',
tstart='s', tend='s') transient_tstart='s', transient_duration='s')
rescale_dictionary = dict( rescale_dictionary = dict(
transient_duration={'multiplier': 1/86400., transient_duration={'multiplier': 1/86400.,
'label': 'Transient duration'}, 'unit': 'day',
'symbol': 'Transient duration'},
transient_tstart={ transient_tstart={
'multiplier': 1/86400., 'multiplier': 1/86400.,
'subtractor': 'minStartTime',
'unit': 'day',
'label': 'Transient start-time \n days after minStartTime'} 'label': 'Transient start-time \n days after minStartTime'}
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment