diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py index 67a21080a39ae1a3fe75a3bbeb4ce45ba9f57a31..1edffa4623110deae488653bd932dff45c75286e 100644 --- a/pyfstat/mcmc_based_searches.py +++ b/pyfstat/mcmc_based_searches.py @@ -41,11 +41,10 @@ class MCMCSearch(core.BaseSearchClass): detectors: str Two character reference to the detectors to use, specify None for no contraint and comma separate for multiple references. - nsteps: list (m,) - List specifying the number of steps to take, the last two entries - give the nburn and nprod of the 'production' run, all entries - before are for iterative initialisation steps (usually just one) - e.g. [1000, 1000, 500]. + nsteps: list (2,) + Number of burn-in and production steps to take, [nburn, nprod]. See + `pyfstat.MCMCSearch.setup_initialisation()` for details on adding + initialisation steps. nwalkers, ntemps: int, The number of walkers and temperates to use in the parallel tempered PTSampler. @@ -103,7 +102,7 @@ class MCMCSearch(core.BaseSearchClass): maxStartTime, sftfilepattern=None, detectors=None, nsteps=[100, 100], nwalkers=100, ntemps=1, log10temperature_min=-5, theta_initial=None, - scatter_val=1e-10, rhohatmax=1000, binary=False, BSGL=False, + rhohatmax=1000, binary=False, BSGL=False, SSBprec=None, minCoverFreq=None, maxCoverFreq=None, injectSources=None, assumeSqrtSX=None): @@ -137,7 +136,6 @@ class MCMCSearch(core.BaseSearchClass): def _log_input(self): logging.info('theta_prior = {}'.format(self.theta_prior)) logging.info('nwalkers={}'.format(self.nwalkers)) - logging.info('scatter_val = {}'.format(self.scatter_val)) logging.info('nsteps = {}'.format(self.nsteps)) logging.info('ntemps = {}'.format(self.ntemps)) logging.info('log10temperature_min = {}'.format( @@ -278,6 +276,30 @@ class MCMCSearch(core.BaseSearchClass): else: raise ValueError('test_type {} not understood'.format(test_type)) + def setup_initialisation(self, nburn0, scatter_val=1e-10): + """ Add an initialisation step to the MCMC run + + If called prior to `run()`, adds an intial step in which the MCMC + simulation is run for `nburn0` steps. After this, the MCMC simulation + continues in the usual manner (i.e. for nburn and nprod steps), but the + walkers are reset scattered around the maximum likelihood position + of the initialisation step. + + Parameters + ---------- + nburn0: int + Number of initialisation steps to take + scatter_val: float + Relative number to scatter walkers around the maximum likelihood + position after the initialisation step + + """ + + logging.info('Setting up initialisation with nburn0={}, scatter_val={}' + .format(nburn0, scatter_val)) + self.nsteps = [nburn0] + self.nsteps + self.scatter_val = scatter_val + def _test_autocorr_convergence(self, i, sampler, test=True, n_cut=5): try: acors = np.zeros((self.ntemps, self.ndim)) @@ -451,6 +473,7 @@ class MCMCSearch(core.BaseSearchClass): p0 = self._apply_corrections_to_p0(p0) self._check_initial_points(p0) + # Run initialisation steps if required ninit_steps = len(self.nsteps) - 2 for j, n in enumerate(self.nsteps[:-2]): logging.info('Running {}/{} initialisation with {} steps'.format( @@ -1139,7 +1162,7 @@ class MCMCSearch(core.BaseSearchClass): def _get_data_dictionary_to_save(self): d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers, ntemps=self.ntemps, theta_keys=self.theta_keys, - theta_prior=self.theta_prior, scatter_val=self.scatter_val, + theta_prior=self.theta_prior, log10temperature_min=self.log10temperature_min, BSGL=self.BSGL) return d @@ -1549,7 +1572,7 @@ class MCMCGlitchSearch(MCMCSearch): maxStartTime, sftfilepattern=None, detectors=None, nsteps=[100, 100], nwalkers=100, ntemps=1, log10temperature_min=-5, theta_initial=None, - scatter_val=1e-10, rhohatmax=1000, binary=False, BSGL=False, + rhohatmax=1000, binary=False, BSGL=False, SSBprec=None, minCoverFreq=None, maxCoverFreq=None, injectSources=None, assumeSqrtSX=None, dtglitchmin=1*86400, theta0_idx=0, nglitch=1): @@ -1678,7 +1701,7 @@ class MCMCGlitchSearch(MCMCSearch): def _get_data_dictionary_to_save(self): d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers, ntemps=self.ntemps, theta_keys=self.theta_keys, - theta_prior=self.theta_prior, scatter_val=self.scatter_val, + theta_prior=self.theta_prior, log10temperature_min=self.log10temperature_min, theta0_idx=self.theta0_idx, BSGL=self.BSGL) return d @@ -1758,7 +1781,7 @@ class MCMCSemiCoherentSearch(MCMCSearch): maxStartTime, sftfilepattern=None, detectors=None, nsteps=[100, 100], nwalkers=100, ntemps=1, log10temperature_min=-5, theta_initial=None, - scatter_val=1e-10, rhohatmax=1000, binary=False, BSGL=False, + rhohatmax=1000, binary=False, BSGL=False, SSBprec=None, minCoverFreq=None, maxCoverFreq=None, injectSources=None, assumeSqrtSX=None, nsegs=None): @@ -1792,7 +1815,7 @@ class MCMCSemiCoherentSearch(MCMCSearch): def _get_data_dictionary_to_save(self): d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers, ntemps=self.ntemps, theta_keys=self.theta_keys, - theta_prior=self.theta_prior, scatter_val=self.scatter_val, + theta_prior=self.theta_prior, log10temperature_min=self.log10temperature_min, BSGL=self.BSGL, nsegs=self.nsegs) return d @@ -1830,7 +1853,6 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch): def _get_data_dictionary_to_save(self): d = dict(nwalkers=self.nwalkers, ntemps=self.ntemps, theta_keys=self.theta_keys, theta_prior=self.theta_prior, - scatter_val=self.scatter_val, log10temperature_min=self.log10temperature_min, BSGL=self.BSGL, run_setup=self.run_setup) return d