From a09a1bbeed4d75be406f722255ccf2127556c93c Mon Sep 17 00:00:00 2001 From: Gregory Ashton <gregory.ashton@ligo.org> Date: Mon, 10 Apr 2017 10:26:52 +0200 Subject: [PATCH] Improvements to the way labels are handled Adds proper unit support by default, and options ot customise through the dictionaries --- pyfstat/mcmc_based_searches.py | 145 +++++++++++++++++++++++++-------- 1 file changed, 110 insertions(+), 35 deletions(-) diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py index 9c4adfd..b5d0128 100644 --- a/pyfstat/mcmc_based_searches.py +++ b/pyfstat/mcmc_based_searches.py @@ -22,6 +22,15 @@ import helper_functions class MCMCSearch(core.BaseSearchClass): """ MCMC search using ComputeFstat""" + + symbol_dictionary = dict( + F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', alpha=r'$\alpha$', + delta='$\delta$') + unit_dictionary = dict( + F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad') + rescale_dictionary = {} + + @helper_functions.initializer def __init__(self, label, outdir, theta_prior, tref, minStartTime, maxStartTime, sftfilepath=None, nsteps=[100, 100], @@ -95,13 +104,6 @@ class MCMCSearch(core.BaseSearchClass): if args.clean and os.path.isfile(self.pickle_path): os.rename(self.pickle_path, self.pickle_path+".old") - self.symbol_dictionary = dict( - F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', alpha=r'$\alpha$', - delta='$\delta$') - self.unit_dictionary = dict( - F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad') - self.rescale_dictionary = {} - self.log_input() def log_input(self): @@ -416,21 +418,90 @@ class MCMCSearch(core.BaseSearchClass): self.lnlikes = lnlikes self.save_data(sampler, samples, lnprobs, lnlikes) - def scale_samples(self, samples, symbols, theta_keys): + def get_rescale_multiplier_for_key(self, key): + """ Get the rescale multiplier from the rescale_dictionary + + Can either be a float, a string (in which case it is interpretted as + a attribute of the MCMCSearch class, e.g. minStartTime, or non-existent + in which case 0 is returned + """ + if key not in self.rescale_dictionary: + return 1 + + if 'multiplier' in self.rescale_dictionary[key]: + val = self.rescale_dictionary[key]['multiplier'] + if type(val) == str: + if hasattr(self, val): + multiplier = getattr( + self, self.rescale_dictionary[key]['multiplier']) + else: + raise ValueError( + "multiplier {} not a class attribute".format(val)) + else: + multiplier = val + else: + multiplier = 1 + return multiplier + + def get_rescale_subtractor_for_key(self, key): + """ Get the rescale subtractor from the rescale_dictionary + + Can either be a float, a string (in which case it is interpretted as + a attribute of the MCMCSearch class, e.g. minStartTime, or non-existent + in which case 0 is returned + """ + if key not in self.rescale_dictionary: + return 0 + + if 'subtractor' in self.rescale_dictionary[key]: + val = self.rescale_dictionary[key]['subtractor'] + if type(val) == str: + if hasattr(self, val): + subtractor = getattr( + self, self.rescale_dictionary[key]['subtractor']) + else: + raise ValueError( + "subtractor {} not a class attribute".format(val)) + else: + subtractor = val + else: + subtractor = 0 + return subtractor + + def scale_samples(self, samples, theta_keys): + """ Scale the samples using the rescale_dictionary """ for key in theta_keys: if key in self.rescale_dictionary: idx = theta_keys.index(key) s = samples[:, idx] - if 'subtractor' in self.scale_dictionary[key]: - s = self.scale_dictionary[key]['subtractor'] - s - if 'multipler' in self.scale_dictionary[key]: - s *= self.scale_dictionary[key]['multipler'] + subtractor = self.get_rescale_subtractor_for_key(key) + s = s - subtractor + multiplier = self.get_rescale_multiplier_for_key(key) + s *= multiplier samples[:, idx] = s - if 'label' in self.scale_dictionary['key']: - symbols[idx] = self.scale_dictionary[key]['label'] + return samples + + def get_labels(self): + """ Combine the units, symbols and rescaling to give labels """ - return samples, symbols + labels = [] + for key in self.theta_keys: + label = None + s = self.symbol_dictionary[key] + s.replace('_{glitch}', r'_\textrm{glitch}') + u = self.unit_dictionary[key] + if key in self.rescale_dictionary: + if 'symbol' in self.rescale_dictionary[key]: + s = self.rescale_dictionary[key]['symbol'] + if 'label' in self.rescale_dictionary[key]: + label = self.rescale_dictionary[key]['label'] + if 'unit' in self.rescale_dictionary[key]: + u = self.rescale_dictionary[key]['unit'] + if label is None: + label = '{} \n [{}]'.format(s, u) + labels.append(label) + return labels def plot_corner(self, figsize=(7, 7), tglitch_ratio=False, add_prior=False, nstds=None, label_offset=0.4, @@ -451,12 +522,9 @@ class MCMCSearch(core.BaseSearchClass): figsize=figsize) samples_plt = copy.copy(self.samples) - theta_symbols_plt = copy.copy(self.theta_symbols) + labels = self.get_labels() - samples_plt, theta_symbols_plt = self.scale_samples( - samples_plt, theta_symbols_plt, self.theta_keys) - theta_symbols_plt = [s.replace('_{glitch}', r'_\textrm{glitch}') - for s in theta_symbols_plt] + samples_plt = self.scale_samples(samples_plt, self.theta_keys) if tglitch_ratio: for j, k in enumerate(self.theta_keys): @@ -465,7 +533,7 @@ class MCMCSearch(core.BaseSearchClass): samples_plt[:, j] = ( s - self.minStartTime)/( self.maxStartTime - self.minStartTime) - theta_symbols_plt[j] = r'$R_{\textrm{glitch}}$' + labels[j] = r'$R_{\textrm{glitch}}$' if type(nstds) is int and 'range' not in kwargs: _range = [] @@ -477,7 +545,7 @@ class MCMCSearch(core.BaseSearchClass): _range = None fig_triangle = corner.corner(samples_plt, - labels=theta_symbols_plt, + labels=labels, fig=fig, bins=50, max_n_ticks=4, @@ -515,9 +583,11 @@ class MCMCSearch(core.BaseSearchClass): s = samples[:, i] prior = self.generic_lnprior(**self.theta_prior[key]) x = np.linspace(s.min(), s.max(), 100) + multiplier = self.get_rescale_multiplier_for_key(key) + subtractor = self.get_rescale_subtractor_for_key(key) ax2 = ax.twinx() ax2.get_yaxis().set_visible(False) - ax2.plot(x, [prior(xi) for xi in x], '-r') + ax2.plot(x, [(prior(xi)-subtractor)*multiplier for xi in x], '-r') ax.set_xlim(xlim) def plot_prior_posterior(self, normal_stds=2): @@ -1202,6 +1272,16 @@ class MCMCSearch(core.BaseSearchClass): class MCMCGlitchSearch(MCMCSearch): """ MCMC search using the SemiCoherentGlitchSearch """ + + symbol_dictionary = dict( + F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', alpha=r'$\alpha$', + delta='$\delta$', delta_F0='$\delta f$', + delta_F1='$\delta \dot{f}$', tglitch='$t_\mathrm{glitch}$') + unit_dictionary = dict( + F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad', + delta_F0='Hz', delta_F1='Hz/s', tglitch='s') + rescale_dictionary = dict() + @helper_functions.initializer def __init__(self, label, outdir, sftfilepath, theta_prior, tref, minStartTime, maxStartTime, nglitch=1, nsteps=[100, 100], @@ -1286,14 +1366,6 @@ class MCMCGlitchSearch(MCMCSearch): os.rename(self.pickle_path, self.pickle_path+".old") self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use() - self.symbol_dictionary = dict( - F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', alpha=r'$\alpha$', - delta='$\delta$', delta_F0='$\delta f$', - delta_F1='$\delta \dot{f}$', tglitch='$t_\mathrm{glitch}$') - self.unit_dictionary = dict( - F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad', - delta_F0='Hz', delta_F1='Hz/s', tglitch='s') - self.rescale_dictionary = dict() self.log_input() def initiate_search_object(self): @@ -1844,17 +1916,20 @@ class MCMCTransientSearch(MCMCSearch): symbol_dictionary = dict( F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', - alpha=r'$\alpha$', delta='$\delta$', tstart='$t_\mathrm{start}$', - tend='$t_\mathrm{end}$') + alpha=r'$\alpha$', delta='$\delta$', + transient_tstart='$t_\mathrm{start}$', transient_duration='$\Delta T$') unit_dictionary = dict( F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad', - tstart='s', tend='s') + transient_tstart='s', transient_duration='s') rescale_dictionary = dict( transient_duration={'multiplier': 1/86400., - 'label': 'Transient duration'}, + 'unit': 'day', + 'symbol': 'Transient duration'}, transient_tstart={ 'multiplier': 1/86400., + 'subtractor': 'minStartTime', + 'unit': 'day', 'label': 'Transient start-time \n days after minStartTime'} ) -- GitLab