diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py
index dda1de89fa592f2ee908871cc8580fd9deb245a5..a95424db50f0637d1c228ddc3d34a30c4bc4caad 100644
--- a/pyfstat/mcmc_based_searches.py
+++ b/pyfstat/mcmc_based_searches.py
@@ -22,7 +22,64 @@ import pyfstat.helper_functions as helper_functions
 
 
 class MCMCSearch(core.BaseSearchClass):
-    """ MCMC search using ComputeFstat"""
+    """ MCMC search using ComputeFstat
+
+    Parameters
+    ----------
+    label, outdir: str
+        A label and directory to read/write data from/to
+    sftfilepattern: str
+        Pattern to match SFTs using wildcards (*?) and ranges [0-9];
+        multiple patterns can be given separated by colons.
+    theta_prior: dict
+        Dictionary of priors and fixed values for the search parameters.
+        For each parameter (key of the dict), if it is to be held fixed
+        the value should be the constant float; if it is to be searched,
+        the value should be a dictionary of the prior.
+    theta_initial: dict, array, (None)
+        Either a dictionary of distributions about which to distribute the
+        initial walkers, an array (from which the walkers will be
+        scattered by scatter_val), or None in which case the prior is used.
+    tref, minStartTime, maxStartTime: int
+        GPS seconds of the reference time, start time and end time
+    nsteps: list (m,)
+        List specifying the number of steps to take, the last two entries
+        give the nburn and nprod of the 'production' run, all entries
+        before are for iterative initialisation steps (usually just one)
+        e.g. [1000, 1000, 500].
+    nwalkers, ntemps: int
+        The number of walkers and temperatures to use in the parallel
+        tempered PTSampler.
+    log10temperature_min: float < 0
+        The log_10(tmin) value, the set of betas passed to PTSampler are
+        generated from np.logspace(0, log10temperature_min, ntemps).
+    rhohatmax: float
+        Upper bound for the SNR scale parameter (required to normalise the
+        Bayes factor) - this needs to be carefully set when using the
+        evidence.
+    binary: bool
+        If true, search over binary parameters
+    detectors: str
+        Two character reference to the data to use, specify None for no
+        constraint.
+    minCoverFreq, maxCoverFreq: float
+        Minimum and maximum instantaneous frequency which will be covered
+        over the SFT time span as passed to CreateFstatInput
+
+    Attributes
+    ----------
+    symbol_dictionary: dict
+        Key, val pairs of the parameters (i.e. `F0`, `F1`) to LaTeX math
+        symbols for plots
+    unit_dictionary: dict
+        Key, val pairs of the parameters (i.e. `F0`, `F1`) and their
+        units (i.e. `Hz`)
+    transform_dictionary: dict
+        Key, val pairs of the parameters (i.e. `F0`, `F1`), where the value
+        is itself a dictionary which can contain a `multiplier`,
+        `subtractor`, or `unit` by which to transform and update the units.
+
+    """
 
     symbol_dictionary = dict(
         F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', Alpha=r'$\alpha$',
@@ -31,7 +88,7 @@ class MCMCSearch(core.BaseSearchClass):
     unit_dictionary = dict(
         F0='Hz', F1='Hz/s', F2='Hz/s$^2$', Alpha=r'rad', Delta='rad',
         asini='', period='s', ecc='', tp='', argp='')
-    rescale_dictionary = {}
+    transform_dictionary = {}
 
     @helper_functions.initializer
     def __init__(self, label, outdir, theta_prior, tref, minStartTime,
@@ -41,50 +98,6 @@ class MCMCSearch(core.BaseSearchClass):
                  binary=False, BSGL=False, minCoverFreq=None, SSBprec=None,
                  maxCoverFreq=None, detectors=None, injectSources=None,
                  assumeSqrtSX=None):
-        """
-        Parameters
-        ----------
-        label, outdir: str
-            A label and directory to read/write data from/to
-        sftfilepattern: str
-            Pattern to match SFTs using wildcards (*?) and ranges [0-9];
-            mutiple patterns can be given separated by colons.
-        theta_prior: dict
-            Dictionary of priors and fixed values for the search parameters.
-            For each parameters (key of the dict), if it is to be held fixed
-            the value should be the constant float, if it is be searched, the
-            value should be a dictionary of the prior.
-        theta_initial: dict, array, (None)
-            Either a dictionary of distribution about which to distribute the
-            initial walkers about, an array (from which the walkers will be
-            scattered by scatter_val, or None in which case the prior is used.
-        tref, minStartTime, maxStartTime: int
-            GPS seconds of the reference time, start time and end time
-        nsteps: list (m,)
-            List specifying the number of steps to take, the last two entries
-            give the nburn and nprod of the 'production' run, all entries
-            before are for iterative initialisation steps (usually just one)
-            e.g. [1000, 1000, 500].
-        nwalkers, ntemps: int,
-            The number of walkers and temperates to use in the parallel
-            tempered PTSampler.
-        log10temperature_min float < 0
-            The log_10(tmin) value, the set of betas passed to PTSampler are
-            generated from np.logspace(0, log10temperature_min, ntemps).
-        rhohatmax: float
-            Upper bound for the SNR scale parameter (required to normalise the
-            Bayes factor) - this needs to be carefully set when using the
-            evidence.
-        binary: Bool
-            If true, search over binary parameters
-        detectors: str
-            Two character reference to the data to use, specify None for no
-            contraint.
-        minCoverFreq, maxCoverFreq: float
-            Minimum and maximum instantaneous frequency which will be covered
-            over the SFT time span as passed to CreateFstatInput
-
-        """
 
         if os.path.isdir(outdir) is False:
             os.mkdir(outdir)
@@ -108,12 +121,11 @@ class MCMCSearch(core.BaseSearchClass):
             os.rename(self.pickle_path, self.pickle_path+".old")
 
         self._set_likelihoodcoef()
+        self._log_input()
 
     def _set_likelihoodcoef(self):
         self.likelihoodcoef = np.log(70./self.rhohatmax**4)
-        self._log_input()
-
     def _log_input(self):
         logging.info('theta_prior = {}'.format(self.theta_prior))
         logging.info('nwalkers={}'.format(self.nwalkers))
@@ -224,15 +236,9 @@ class MCMCSearch(core.BaseSearchClass):
 
         return p0
 
-    def _OLD_run_sampler_with_progress_bar(self, sampler, ns, p0):
-        for result in tqdm(sampler.sample(p0, iterations=ns), total=ns):
-            pass
-        return sampler
-
     def setup_burnin_convergence_testing(
             self, n=10, test_type='autocorr', windowed=False, **kwargs):
-        """
-        If called, convergence testing is used during the MCMC simulation
+        """ Set up convergence testing during the MCMC simulation
 
         Parameters
         ----------
@@ -245,6 +251,10 @@ class MCMCSearch(core.BaseSearchClass):
         windowed: bool
            If True, only calculate the convergence test in a window of
            length `n`
+        **kwargs:
+            Passed to either `_test_autocorr_convergence()` or
+            `_test_GR_convergence()` depending on `test_type`.
+ """ logging.info('Setting up convergence testing') self.convergence_n = n @@ -254,13 +264,13 @@ class MCMCSearch(core.BaseSearchClass): self.convergence_diagnostic = [] self.convergence_diagnosticx = [] if test_type in ['autocorr']: - self._get_convergence_test = self.test_autocorr_convergence + self._get_convergence_test = self._test_autocorr_convergence elif test_type in ['GR']: - self._get_convergence_test = self.test_GR_convergence + self._get_convergence_test = self._test_GR_convergence else: raise ValueError('test_type {} not understood'.format(test_type)) - def test_autocorr_convergence(self, i, sampler, test=True, n_cut=5): + def _test_autocorr_convergence(self, i, sampler, test=True, n_cut=5): try: acors = np.zeros((self.ntemps, self.ndim)) for temp in range(self.ntemps): @@ -284,7 +294,7 @@ class MCMCSearch(core.BaseSearchClass): if test: return i > n_cut * np.max(c) - def test_GR_convergence(self, i, sampler, test=True, R=1.1): + def _test_GR_convergence(self, i, sampler, test=True, R=1.1): if self.convergence_windowed: s = sampler.chain[0, :, i-self.convergence_n+1:i+1, :] else: @@ -403,6 +413,11 @@ class MCMCSearch(core.BaseSearchClass): **kwargs: Passed to _plot_walkers to control the figures + Returns + ------- + sampler: emcee.ptsampler.PTSampler + The emcee ptsampler object + """ self.old_data_is_okay_to_use = self._check_old_data_is_okay_to_use() @@ -473,21 +488,21 @@ class MCMCSearch(core.BaseSearchClass): return sampler def _get_rescale_multiplier_for_key(self, key): - """ Get the rescale multiplier from the rescale_dictionary + """ Get the rescale multiplier from the transform_dictionary Can either be a float, a string (in which case it is interpretted as a attribute of the MCMCSearch class, e.g. minStartTime, or non-existent in which case 0 is returned """ - if key not in self.rescale_dictionary: + if key not in self.transform_dictionary: return 1 - if 'multiplier' in self.rescale_dictionary[key]: - val = self.rescale_dictionary[key]['multiplier'] + if 'multiplier' in self.transform_dictionary[key]: + val = self.transform_dictionary[key]['multiplier'] if type(val) == str: if hasattr(self, val): multiplier = getattr( - self, self.rescale_dictionary[key]['multiplier']) + self, self.transform_dictionary[key]['multiplier']) else: raise ValueError( "multiplier {} not a class attribute".format(val)) @@ -498,21 +513,21 @@ class MCMCSearch(core.BaseSearchClass): return multiplier def _get_rescale_subtractor_for_key(self, key): - """ Get the rescale subtractor from the rescale_dictionary + """ Get the rescale subtractor from the transform_dictionary Can either be a float, a string (in which case it is interpretted as a attribute of the MCMCSearch class, e.g. 
minStartTime, or non-existent in which case 0 is returned """ - if key not in self.rescale_dictionary: + if key not in self.transform_dictionary: return 0 - if 'subtractor' in self.rescale_dictionary[key]: - val = self.rescale_dictionary[key]['subtractor'] + if 'subtractor' in self.transform_dictionary[key]: + val = self.transform_dictionary[key]['subtractor'] if type(val) == str: if hasattr(self, val): subtractor = getattr( - self, self.rescale_dictionary[key]['subtractor']) + self, self.transform_dictionary[key]['subtractor']) else: raise ValueError( "subtractor {} not a class attribute".format(val)) @@ -523,9 +538,9 @@ class MCMCSearch(core.BaseSearchClass): return subtractor def _scale_samples(self, samples, theta_keys): - """ Scale the samples using the rescale_dictionary """ + """ Scale the samples using the transform_dictionary """ for key in theta_keys: - if key in self.rescale_dictionary: + if key in self.transform_dictionary: idx = theta_keys.index(key) s = samples[:, idx] subtractor = self._get_rescale_subtractor_for_key(key) @@ -545,13 +560,13 @@ class MCMCSearch(core.BaseSearchClass): s = self.symbol_dictionary[key] s.replace('_{glitch}', r'_\textrm{glitch}') u = self.unit_dictionary[key] - if key in self.rescale_dictionary: - if 'symbol' in self.rescale_dictionary[key]: - s = self.rescale_dictionary[key]['symbol'] - if 'label' in self.rescale_dictionary[key]: - label = self.rescale_dictionary[key]['label'] - if 'unit' in self.rescale_dictionary[key]: - u = self.rescale_dictionary[key]['unit'] + if key in self.transform_dictionary: + if 'symbol' in self.transform_dictionary[key]: + s = self.transform_dictionary[key]['symbol'] + if 'label' in self.transform_dictionary[key]: + label = self.transform_dictionary[key]['label'] + if 'unit' in self.transform_dictionary[key]: + u = self.transform_dictionary[key]['unit'] if label is None: label = '{} \n [{}]'.format(s, u) labels.append(label) @@ -592,8 +607,13 @@ class MCMCSearch(core.BaseSearchClass): namely (ndim, ndim) save_fig: bool If true, save the figure, else return the fig, axes + **kwargs: + Passed to corner.corner - Note: kwargs are passed on to corner.corner + Returns + ------- + fig, axes: + The matplotlib figure and axes, only returned if save_fig = False """ @@ -798,7 +818,7 @@ class MCMCSearch(core.BaseSearchClass): Parameters ---------- - kwargs: dict + **kwargs: A dictionary containing 'type' of pdf and shape parameters """ @@ -1139,7 +1159,7 @@ class MCMCSearch(core.BaseSearchClass): pickle.dump(d, File) def get_saved_data_dictionary(self): - """ Returns dictionary of the data saved as pickle """ + """ Returns dictionary of the data saved in the pickle """ with open(self.pickle_path, "r") as File: d = pickle.load(File) return d @@ -1264,6 +1284,12 @@ class MCMCSearch(core.BaseSearchClass): ---------- threshold: float [0, 1] Fraction of the uniform prior to test (at upper and lower bound) + + Returns + ------- + return_flag: bool + IF true, the samples are railing + """ return_flag = False for s, k in zip(self.samples.T, self.theta_keys): @@ -1307,6 +1333,7 @@ class MCMCSearch(core.BaseSearchClass): f.write('{} = {:1.16e}\n'.format(key, val)) def generate_loudest(self): + """ Use lalapps_ComputeFstatistic_v2 to produce a .loudest file """ self.write_par() params = read_par(label=self.label, outdir=self.outdir) for key in ['Alpha', 'Delta', 'F0', 'F1']: @@ -1322,6 +1349,7 @@ class MCMCSearch(core.BaseSearchClass): subprocess.call([cmd], shell=True) def write_prior_table(self): + """ Generate a .tex file of the prior 
""" with open('{}/{}_prior.tex'.format(self.outdir, self.label), 'w') as f: f.write(r"\begin{tabular}{c l c} \hline" + '\n' r"Parameter & & & \\ \hhline{====}") @@ -1482,7 +1510,64 @@ class MCMCSearch(core.BaseSearchClass): class MCMCGlitchSearch(MCMCSearch): - """ MCMC search using the SemiCoherentGlitchSearch """ + """ MCMC search using the SemiCoherentGlitchSearch + + See the parent class MCMCSearch for all inherited methods and attributes + + Parameters + ---------- + label, outdir: str + A label and directory to read/write data from/to + sftfilepattern: str + Pattern to match SFTs using wildcards (*?) and ranges [0-9]; + mutiple patterns can be given separated by colons. + theta_prior: dict + Dictionary of priors and fixed values for the search parameters. + For each parameters (key of the dict), if it is to be held fixed + the value should be the constant float, if it is be searched, the + value should be a dictionary of the prior. + theta_initial: dict, array, (None) + Either a dictionary of distribution about which to distribute the + initial walkers about, an array (from which the walkers will be + scattered by scatter_val), or None in which case the prior is used. + scatter_val, float or ndim array + Size of scatter to use about the initialisation step, if given as + an array it must be of length ndim and the order is given by + theta_keys + nglitch: int + The number of glitches to allow + tref, minStartTime, maxStartTime: int + GPS seconds of the reference time, start time and end time + nsteps: list (m,) + List specifying the number of steps to take, the last two entries + give the nburn and nprod of the 'production' run, all entries + before are for iterative initialisation steps (usually just one) + e.g. [1000, 1000, 500]. + dtglitchmin: int + The minimum duration (in seconds) of a segment between two glitches + or a glitch and the start/end of the data + rhohatmax: float + Upper bound for the SNR scale parameter (required to normalise the + Bayes factor) - this needs to be carefully set when using the + evidence. + nwalkers, ntemps: int, + The number of walkers and temperates to use in the parallel + tempered PTSampler. + log10temperature_min float < 0 + The log_10(tmin) value, the set of betas passed to PTSampler are + generated from np.logspace(0, log10temperature_min, ntemps). + theta0_idx, int + Index (zero-based) of which segment the theta refers to - uyseful + if providing a tight prior on theta to allow the signal to jump + too theta (and not just from) + detectors: str + Two character reference to the data to use, specify None for no + contraint. + minCoverFreq, maxCoverFreq: float + Minimum and maximum instantaneous frequency which will be covered + over the SFT time span as passed to CreateFstatInput + + """ symbol_dictionary = dict( F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', Alpha=r'$\alpha$', @@ -1491,7 +1576,7 @@ class MCMCGlitchSearch(MCMCSearch): unit_dictionary = dict( F0='Hz', F1='Hz/s', F2='Hz/s$^2$', Alpha=r'rad', Delta='rad', delta_F0='Hz', delta_F1='Hz/s', tglitch='s') - rescale_dictionary = dict( + transform_dictionary = dict( tglitch={ 'multiplier': 1/86400., 'subtractor': 'minStartTime', @@ -1507,61 +1592,6 @@ class MCMCGlitchSearch(MCMCSearch): dtglitchmin=1*86400, theta0_idx=0, detectors=None, BSGL=False, minCoverFreq=None, maxCoverFreq=None, injectSources=None): - """ - Parameters - ---------- - label, outdir: str - A label and directory to read/write data from/to - sftfilepattern: str - Pattern to match SFTs using wildcards (*?) 
and ranges [0-9]; - mutiple patterns can be given separated by colons. - theta_prior: dict - Dictionary of priors and fixed values for the search parameters. - For each parameters (key of the dict), if it is to be held fixed - the value should be the constant float, if it is be searched, the - value should be a dictionary of the prior. - theta_initial: dict, array, (None) - Either a dictionary of distribution about which to distribute the - initial walkers about, an array (from which the walkers will be - scattered by scatter_val), or None in which case the prior is used. - scatter_val, float or ndim array - Size of scatter to use about the initialisation step, if given as - an array it must be of length ndim and the order is given by - theta_keys - nglitch: int - The number of glitches to allow - tref, minStartTime, maxStartTime: int - GPS seconds of the reference time, start time and end time - nsteps: list (m,) - List specifying the number of steps to take, the last two entries - give the nburn and nprod of the 'production' run, all entries - before are for iterative initialisation steps (usually just one) - e.g. [1000, 1000, 500]. - dtglitchmin: int - The minimum duration (in seconds) of a segment between two glitches - or a glitch and the start/end of the data - rhohatmax: float - Upper bound for the SNR scale parameter (required to normalise the - Bayes factor) - this needs to be carefully set when using the - evidence. - nwalkers, ntemps: int, - The number of walkers and temperates to use in the parallel - tempered PTSampler. - log10temperature_min float < 0 - The log_10(tmin) value, the set of betas passed to PTSampler are - generated from np.logspace(0, log10temperature_min, ntemps). - theta0_idx, int - Index (zero-based) of which segment the theta refers to - uyseful - if providing a tight prior on theta to allow the signal to jump - too theta (and not just from) - detectors: str - Two character reference to the data to use, specify None for no - contraint. 
- minCoverFreq, maxCoverFreq: float - Minimum and maximum instantaneous frequency which will be covered - over the SFT time span as passed to CreateFstatInput - - """ if os.path.isdir(outdir) is False: os.mkdir(outdir) @@ -1589,12 +1619,12 @@ class MCMCGlitchSearch(MCMCSearch): def _initiate_search_object(self): logging.info('Setting up search object') self.search = core.SemiCoherentGlitchSearch( - label=self.label, outdir=self.outdir, sftfilepattern=self.sftfilepattern, - tref=self.tref, minStartTime=self.minStartTime, - maxStartTime=self.maxStartTime, minCoverFreq=self.minCoverFreq, - maxCoverFreq=self.maxCoverFreq, detectors=self.detectors, BSGL=self.BSGL, - nglitch=self.nglitch, theta0_idx=self.theta0_idx, - injectSources=self.injectSources) + label=self.label, outdir=self.outdir, + sftfilepattern=self.sftfilepattern, tref=self.tref, + minStartTime=self.minStartTime, maxStartTime=self.maxStartTime, + minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq, + detectors=self.detectors, BSGL=self.BSGL, nglitch=self.nglitch, + theta0_idx=self.theta0_idx, injectSources=self.injectSources) def logp(self, theta_vals, theta_prior, theta_keys, search): if self.nglitch > 1: @@ -1759,9 +1789,6 @@ class MCMCSemiCoherentSearch(MCMCSearch): detectors=None, BSGL=False, minStartTime=None, maxStartTime=None, minCoverFreq=None, maxCoverFreq=None, injectSources=None, assumeSqrtSX=None): - """ - - """ if os.path.isdir(outdir) is False: os.mkdir(outdir) @@ -2083,7 +2110,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch): class MCMCTransientSearch(MCMCSearch): - """ MCMC search for a transient signal using the ComputeFstat """ + """ MCMC search for a transient signal using ComputeFstat """ symbol_dictionary = dict( F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', @@ -2093,7 +2120,7 @@ class MCMCTransientSearch(MCMCSearch): F0='Hz', F1='Hz/s', F2='Hz/s$^2$', Alpha=r'rad', Delta='rad', transient_tstart='s', transient_duration='s') - rescale_dictionary = dict( + transform_dictionary = dict( transient_duration={'multiplier': 1/86400., 'unit': 'day', 'symbol': 'Transient duration'},
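
Reviewer note, not part of the patch: a minimal usage sketch of the MCMCSearch interface documented by the new class-level docstring above, to make the Parameters section concrete. The prior-dictionary keys ('type', 'lower', 'upper'), the run() and plot_corner() method names, and all numerical values are assumptions drawn from the surrounding docstrings rather than anything this diff defines.

import pyfstat  # assumes MCMCSearch is exposed at the package level

# Hypothetical prior: fixed parameters are plain floats, searched parameters
# are prior dictionaries (the 'type'/'lower'/'upper' keys are an assumption).
theta_prior = {
    'F0': {'type': 'unif', 'lower': 29.9, 'upper': 30.1},
    'F1': {'type': 'unif', 'lower': -1e-10, 'upper': 0},
    'F2': 0,
    'Alpha': 1.0,
    'Delta': 0.5,
}

search = pyfstat.MCMCSearch(
    label='example_run', outdir='example_data',
    sftfilepattern='example_data/*.sft',
    theta_prior=theta_prior, tref=1000000000,
    minStartTime=1000000000, maxStartTime=1000864000,
    nsteps=[1000, 1000, 500],  # initialisation steps, then nburn and nprod
    nwalkers=100, ntemps=2, log10temperature_min=-1)

search.run()          # assumed entry point; per the added Returns block it gives back the emcee PTSampler
search.plot_corner()  # assumed corner-plot helper; extra kwargs are passed to corner.corner

The subclasses' transform_dictionary entries (e.g. tglitch, transient_duration above) only affect how such plots are labelled and scaled, not the sampling itself.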