From a09a1bbeed4d75be406f722255ccf2127556c93c Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Mon, 10 Apr 2017 10:26:52 +0200
Subject: [PATCH] Improvements to the way labels are handled

Adds proper unit support by default, and options ot customise through
the dictionaries
---
 pyfstat/mcmc_based_searches.py | 145 +++++++++++++++++++++++++--------
 1 file changed, 110 insertions(+), 35 deletions(-)

diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py
index 9c4adfd..b5d0128 100644
--- a/pyfstat/mcmc_based_searches.py
+++ b/pyfstat/mcmc_based_searches.py
@@ -22,6 +22,15 @@ import helper_functions
 
 class MCMCSearch(core.BaseSearchClass):
     """ MCMC search using ComputeFstat"""
+
+    symbol_dictionary = dict(
+        F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', alpha=r'$\alpha$',
+        delta='$\delta$')
+    unit_dictionary = dict(
+        F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad')
+    rescale_dictionary = {}
+
+
     @helper_functions.initializer
     def __init__(self, label, outdir, theta_prior, tref, minStartTime,
                  maxStartTime, sftfilepath=None, nsteps=[100, 100],
@@ -95,13 +104,6 @@ class MCMCSearch(core.BaseSearchClass):
         if args.clean and os.path.isfile(self.pickle_path):
             os.rename(self.pickle_path, self.pickle_path+".old")
 
-        self.symbol_dictionary = dict(
-            F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', alpha=r'$\alpha$',
-            delta='$\delta$')
-        self.unit_dictionary = dict(
-            F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad')
-        self.rescale_dictionary = {}
-
         self.log_input()
 
     def log_input(self):
@@ -416,21 +418,90 @@ class MCMCSearch(core.BaseSearchClass):
         self.lnlikes = lnlikes
         self.save_data(sampler, samples, lnprobs, lnlikes)
 
-    def scale_samples(self, samples, symbols, theta_keys):
+    def get_rescale_multiplier_for_key(self, key):
+        """ Get the rescale multiplier from the rescale_dictionary
+
+        Can either be a float, a string (in which case it is interpretted as
+        a attribute of the MCMCSearch class, e.g. minStartTime, or non-existent
+        in which case 0 is returned
+        """
+        if key not in self.rescale_dictionary:
+            return 1
+
+        if 'multiplier' in self.rescale_dictionary[key]:
+            val = self.rescale_dictionary[key]['multiplier']
+            if type(val) == str:
+                if hasattr(self, val):
+                    multiplier = getattr(
+                        self, self.rescale_dictionary[key]['multiplier'])
+                else:
+                    raise ValueError(
+                        "multiplier {} not a class attribute".format(val))
+            else:
+                multiplier = val
+        else:
+            multiplier = 1
+        return multiplier
+
+    def get_rescale_subtractor_for_key(self, key):
+        """ Get the rescale subtractor from the rescale_dictionary
+
+        Can either be a float, a string (in which case it is interpretted as
+        a attribute of the MCMCSearch class, e.g. minStartTime, or non-existent
+        in which case 0 is returned
+        """
+        if key not in self.rescale_dictionary:
+            return 0
+
+        if 'subtractor' in self.rescale_dictionary[key]:
+            val = self.rescale_dictionary[key]['subtractor']
+            if type(val) == str:
+                if hasattr(self, val):
+                    subtractor = getattr(
+                        self, self.rescale_dictionary[key]['subtractor'])
+                else:
+                    raise ValueError(
+                        "subtractor {} not a class attribute".format(val))
+            else:
+                subtractor = val
+        else:
+            subtractor = 0
+        return subtractor
+
+    def scale_samples(self, samples, theta_keys):
+        """ Scale the samples using the rescale_dictionary """
         for key in theta_keys:
             if key in self.rescale_dictionary:
                 idx = theta_keys.index(key)
                 s = samples[:, idx]
-                if 'subtractor' in self.scale_dictionary[key]:
-                    s = self.scale_dictionary[key]['subtractor'] - s
-                if 'multipler' in self.scale_dictionary[key]:
-                    s *= self.scale_dictionary[key]['multipler']
+                subtractor = self.get_rescale_subtractor_for_key(key)
+                s = s - subtractor
+                multiplier = self.get_rescale_multiplier_for_key(key)
+                s *= multiplier
                 samples[:, idx] = s
 
-                if 'label' in self.scale_dictionary['key']:
-                    symbols[idx] = self.scale_dictionary[key]['label']
+        return samples
+
+    def get_labels(self):
+        """ Combine the units, symbols and rescaling to give labels """
 
-        return samples, symbols
+        labels = []
+        for key in self.theta_keys:
+            label = None
+            s = self.symbol_dictionary[key]
+            s.replace('_{glitch}', r'_\textrm{glitch}')
+            u = self.unit_dictionary[key]
+            if key in self.rescale_dictionary:
+                if 'symbol' in self.rescale_dictionary[key]:
+                    s = self.rescale_dictionary[key]['symbol']
+                if 'label' in self.rescale_dictionary[key]:
+                    label = self.rescale_dictionary[key]['label']
+                if 'unit' in self.rescale_dictionary[key]:
+                    u = self.rescale_dictionary[key]['unit']
+            if label is None:
+                label = '{} \n [{}]'.format(s, u)
+            labels.append(label)
+        return labels
 
     def plot_corner(self, figsize=(7, 7),  tglitch_ratio=False,
                     add_prior=False, nstds=None, label_offset=0.4,
@@ -451,12 +522,9 @@ class MCMCSearch(core.BaseSearchClass):
                                      figsize=figsize)
 
             samples_plt = copy.copy(self.samples)
-            theta_symbols_plt = copy.copy(self.theta_symbols)
+            labels = self.get_labels()
 
-            samples_plt, theta_symbols_plt = self.scale_samples(
-                samples_plt, theta_symbols_plt, self.theta_keys)
-            theta_symbols_plt = [s.replace('_{glitch}', r'_\textrm{glitch}')
-                                 for s in theta_symbols_plt]
+            samples_plt = self.scale_samples(samples_plt, self.theta_keys)
 
             if tglitch_ratio:
                 for j, k in enumerate(self.theta_keys):
@@ -465,7 +533,7 @@ class MCMCSearch(core.BaseSearchClass):
                         samples_plt[:, j] = (
                             s - self.minStartTime)/(
                                 self.maxStartTime - self.minStartTime)
-                        theta_symbols_plt[j] = r'$R_{\textrm{glitch}}$'
+                        labels[j] = r'$R_{\textrm{glitch}}$'
 
             if type(nstds) is int and 'range' not in kwargs:
                 _range = []
@@ -477,7 +545,7 @@ class MCMCSearch(core.BaseSearchClass):
                 _range = None
 
             fig_triangle = corner.corner(samples_plt,
-                                         labels=theta_symbols_plt,
+                                         labels=labels,
                                          fig=fig,
                                          bins=50,
                                          max_n_ticks=4,
@@ -515,9 +583,11 @@ class MCMCSearch(core.BaseSearchClass):
             s = samples[:, i]
             prior = self.generic_lnprior(**self.theta_prior[key])
             x = np.linspace(s.min(), s.max(), 100)
+            multiplier = self.get_rescale_multiplier_for_key(key)
+            subtractor = self.get_rescale_subtractor_for_key(key)
             ax2 = ax.twinx()
             ax2.get_yaxis().set_visible(False)
-            ax2.plot(x, [prior(xi) for xi in x], '-r')
+            ax2.plot(x, [(prior(xi)-subtractor)*multiplier for xi in x], '-r')
             ax.set_xlim(xlim)
 
     def plot_prior_posterior(self, normal_stds=2):
@@ -1202,6 +1272,16 @@ class MCMCSearch(core.BaseSearchClass):
 
 class MCMCGlitchSearch(MCMCSearch):
     """ MCMC search using the SemiCoherentGlitchSearch """
+
+    symbol_dictionary = dict(
+        F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', alpha=r'$\alpha$',
+        delta='$\delta$', delta_F0='$\delta f$',
+        delta_F1='$\delta \dot{f}$', tglitch='$t_\mathrm{glitch}$')
+    unit_dictionary = dict(
+        F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad',
+        delta_F0='Hz', delta_F1='Hz/s', tglitch='s')
+    rescale_dictionary = dict()
+
     @helper_functions.initializer
     def __init__(self, label, outdir, sftfilepath, theta_prior, tref,
                  minStartTime, maxStartTime, nglitch=1, nsteps=[100, 100],
@@ -1286,14 +1366,6 @@ class MCMCGlitchSearch(MCMCSearch):
             os.rename(self.pickle_path, self.pickle_path+".old")
 
         self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
-        self.symbol_dictionary = dict(
-            F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', alpha=r'$\alpha$',
-            delta='$\delta$', delta_F0='$\delta f$',
-            delta_F1='$\delta \dot{f}$', tglitch='$t_\mathrm{glitch}$')
-        self.unit_dictionary = dict(
-            F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad',
-            delta_F0='Hz', delta_F1='Hz/s', tglitch='s')
-        self.rescale_dictionary = dict()
         self.log_input()
 
     def initiate_search_object(self):
@@ -1844,17 +1916,20 @@ class MCMCTransientSearch(MCMCSearch):
 
     symbol_dictionary = dict(
         F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$',
-        alpha=r'$\alpha$', delta='$\delta$', tstart='$t_\mathrm{start}$',
-        tend='$t_\mathrm{end}$')
+        alpha=r'$\alpha$', delta='$\delta$',
+        transient_tstart='$t_\mathrm{start}$', transient_duration='$\Delta T$')
     unit_dictionary = dict(
         F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad',
-        tstart='s', tend='s')
+        transient_tstart='s', transient_duration='s')
 
     rescale_dictionary = dict(
         transient_duration={'multiplier': 1/86400.,
-                            'label': 'Transient duration'},
+                            'unit': 'day',
+                            'symbol': 'Transient duration'},
         transient_tstart={
             'multiplier': 1/86400.,
+            'subtractor': 'minStartTime',
+            'unit': 'day',
             'label': 'Transient start-time \n days after minStartTime'}
             )
 
-- 
GitLab