diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py index 066c70c986b282374a8364963b2bcbef036d1179..9c4adfd8fbfd45f32a708afc3bb3a6ea0776f635 100644 --- a/pyfstat/mcmc_based_searches.py +++ b/pyfstat/mcmc_based_searches.py @@ -100,6 +100,7 @@ class MCMCSearch(core.BaseSearchClass): delta='$\delta$') self.unit_dictionary = dict( F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad') + self.rescale_dictionary = {} self.log_input() @@ -415,6 +416,22 @@ class MCMCSearch(core.BaseSearchClass): self.lnlikes = lnlikes self.save_data(sampler, samples, lnprobs, lnlikes) + def scale_samples(self, samples, symbols, theta_keys): + for key in theta_keys: + if key in self.rescale_dictionary: + idx = theta_keys.index(key) + s = samples[:, idx] + if 'subtractor' in self.scale_dictionary[key]: + s = self.scale_dictionary[key]['subtractor'] - s + if 'multipler' in self.scale_dictionary[key]: + s *= self.scale_dictionary[key]['multipler'] + samples[:, idx] = s + + if 'label' in self.scale_dictionary['key']: + symbols[idx] = self.scale_dictionary[key]['label'] + + return samples, symbols + def plot_corner(self, figsize=(7, 7), tglitch_ratio=False, add_prior=False, nstds=None, label_offset=0.4, dpi=300, rc_context={}, **kwargs): @@ -435,6 +452,9 @@ class MCMCSearch(core.BaseSearchClass): samples_plt = copy.copy(self.samples) theta_symbols_plt = copy.copy(self.theta_symbols) + + samples_plt, theta_symbols_plt = self.scale_samples( + samples_plt, theta_symbols_plt, self.theta_keys) theta_symbols_plt = [s.replace('_{glitch}', r'_\textrm{glitch}') for s in theta_symbols_plt] @@ -1273,6 +1293,7 @@ class MCMCGlitchSearch(MCMCSearch): self.unit_dictionary = dict( F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad', delta_F0='Hz', delta_F1='Hz/s', tglitch='s') + self.rescale_dictionary = dict() self.log_input() def initiate_search_object(self): @@ -1829,6 +1850,14 @@ class MCMCTransientSearch(MCMCSearch): F0='Hz', F1='Hz/s', F2='Hz/s$^2$', alpha=r'rad', delta='rad', tstart='s', tend='s') + rescale_dictionary = dict( + transient_duration={'multiplier': 1/86400., + 'label': 'Transient duration'}, + transient_tstart={ + 'multiplier': 1/86400., + 'label': 'Transient start-time \n days after minStartTime'} + ) + def initiate_search_object(self): logging.info('Setting up search object') self.search = core.ComputeFstat(