Commit 4e796984 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Move initialisation to special function

This adds a new method `setup_initialisation` which can be called to run
the initialisation step. Splitting it out from the main functionality
should make it easier for new users.
parent 4632121f
...@@ -41,11 +41,10 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -41,11 +41,10 @@ class MCMCSearch(core.BaseSearchClass):
detectors: str detectors: str
Two character reference to the detectors to use, specify None for no Two character reference to the detectors to use, specify None for no
contraint and comma separate for multiple references. contraint and comma separate for multiple references.
nsteps: list (m,) nsteps: list (2,)
List specifying the number of steps to take, the last two entries Number of burn-in and production steps to take, [nburn, nprod]. See
give the nburn and nprod of the 'production' run, all entries `pyfstat.MCMCSearch.setup_initialisation()` for details on adding
before are for iterative initialisation steps (usually just one) initialisation steps.
e.g. [1000, 1000, 500].
nwalkers, ntemps: int, nwalkers, ntemps: int,
The number of walkers and temperates to use in the parallel The number of walkers and temperates to use in the parallel
tempered PTSampler. tempered PTSampler.
...@@ -103,7 +102,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -103,7 +102,7 @@ class MCMCSearch(core.BaseSearchClass):
maxStartTime, sftfilepattern=None, detectors=None, maxStartTime, sftfilepattern=None, detectors=None,
nsteps=[100, 100], nwalkers=100, ntemps=1, nsteps=[100, 100], nwalkers=100, ntemps=1,
log10temperature_min=-5, theta_initial=None, log10temperature_min=-5, theta_initial=None,
scatter_val=1e-10, rhohatmax=1000, binary=False, BSGL=False, rhohatmax=1000, binary=False, BSGL=False,
SSBprec=None, minCoverFreq=None, maxCoverFreq=None, SSBprec=None, minCoverFreq=None, maxCoverFreq=None,
injectSources=None, assumeSqrtSX=None): injectSources=None, assumeSqrtSX=None):
...@@ -137,7 +136,6 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -137,7 +136,6 @@ class MCMCSearch(core.BaseSearchClass):
def _log_input(self): def _log_input(self):
logging.info('theta_prior = {}'.format(self.theta_prior)) logging.info('theta_prior = {}'.format(self.theta_prior))
logging.info('nwalkers={}'.format(self.nwalkers)) logging.info('nwalkers={}'.format(self.nwalkers))
logging.info('scatter_val = {}'.format(self.scatter_val))
logging.info('nsteps = {}'.format(self.nsteps)) logging.info('nsteps = {}'.format(self.nsteps))
logging.info('ntemps = {}'.format(self.ntemps)) logging.info('ntemps = {}'.format(self.ntemps))
logging.info('log10temperature_min = {}'.format( logging.info('log10temperature_min = {}'.format(
...@@ -278,6 +276,30 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -278,6 +276,30 @@ class MCMCSearch(core.BaseSearchClass):
else: else:
raise ValueError('test_type {} not understood'.format(test_type)) raise ValueError('test_type {} not understood'.format(test_type))
def setup_initialisation(self, nburn0, scatter_val=1e-10):
""" Add an initialisation step to the MCMC run
If called prior to `run()`, adds an intial step in which the MCMC
simulation is run for `nburn0` steps. After this, the MCMC simulation
continues in the usual manner (i.e. for nburn and nprod steps), but the
walkers are reset scattered around the maximum likelihood position
of the initialisation step.
Parameters
----------
nburn0: int
Number of initialisation steps to take
scatter_val: float
Relative number to scatter walkers around the maximum likelihood
position after the initialisation step
"""
logging.info('Setting up initialisation with nburn0={}, scatter_val={}'
.format(nburn0, scatter_val))
self.nsteps = [nburn0] + self.nsteps
self.scatter_val = scatter_val
def _test_autocorr_convergence(self, i, sampler, test=True, n_cut=5): def _test_autocorr_convergence(self, i, sampler, test=True, n_cut=5):
try: try:
acors = np.zeros((self.ntemps, self.ndim)) acors = np.zeros((self.ntemps, self.ndim))
...@@ -451,6 +473,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -451,6 +473,7 @@ class MCMCSearch(core.BaseSearchClass):
p0 = self._apply_corrections_to_p0(p0) p0 = self._apply_corrections_to_p0(p0)
self._check_initial_points(p0) self._check_initial_points(p0)
# Run initialisation steps if required
ninit_steps = len(self.nsteps) - 2 ninit_steps = len(self.nsteps) - 2
for j, n in enumerate(self.nsteps[:-2]): for j, n in enumerate(self.nsteps[:-2]):
logging.info('Running {}/{} initialisation with {} steps'.format( logging.info('Running {}/{} initialisation with {} steps'.format(
...@@ -1139,7 +1162,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -1139,7 +1162,7 @@ class MCMCSearch(core.BaseSearchClass):
def _get_data_dictionary_to_save(self): def _get_data_dictionary_to_save(self):
d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers, d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers,
ntemps=self.ntemps, theta_keys=self.theta_keys, ntemps=self.ntemps, theta_keys=self.theta_keys,
theta_prior=self.theta_prior, scatter_val=self.scatter_val, theta_prior=self.theta_prior,
log10temperature_min=self.log10temperature_min, log10temperature_min=self.log10temperature_min,
BSGL=self.BSGL) BSGL=self.BSGL)
return d return d
...@@ -1549,7 +1572,7 @@ class MCMCGlitchSearch(MCMCSearch): ...@@ -1549,7 +1572,7 @@ class MCMCGlitchSearch(MCMCSearch):
maxStartTime, sftfilepattern=None, detectors=None, maxStartTime, sftfilepattern=None, detectors=None,
nsteps=[100, 100], nwalkers=100, ntemps=1, nsteps=[100, 100], nwalkers=100, ntemps=1,
log10temperature_min=-5, theta_initial=None, log10temperature_min=-5, theta_initial=None,
scatter_val=1e-10, rhohatmax=1000, binary=False, BSGL=False, rhohatmax=1000, binary=False, BSGL=False,
SSBprec=None, minCoverFreq=None, maxCoverFreq=None, SSBprec=None, minCoverFreq=None, maxCoverFreq=None,
injectSources=None, assumeSqrtSX=None, injectSources=None, assumeSqrtSX=None,
dtglitchmin=1*86400, theta0_idx=0, nglitch=1): dtglitchmin=1*86400, theta0_idx=0, nglitch=1):
...@@ -1678,7 +1701,7 @@ class MCMCGlitchSearch(MCMCSearch): ...@@ -1678,7 +1701,7 @@ class MCMCGlitchSearch(MCMCSearch):
def _get_data_dictionary_to_save(self): def _get_data_dictionary_to_save(self):
d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers, d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers,
ntemps=self.ntemps, theta_keys=self.theta_keys, ntemps=self.ntemps, theta_keys=self.theta_keys,
theta_prior=self.theta_prior, scatter_val=self.scatter_val, theta_prior=self.theta_prior,
log10temperature_min=self.log10temperature_min, log10temperature_min=self.log10temperature_min,
theta0_idx=self.theta0_idx, BSGL=self.BSGL) theta0_idx=self.theta0_idx, BSGL=self.BSGL)
return d return d
...@@ -1758,7 +1781,7 @@ class MCMCSemiCoherentSearch(MCMCSearch): ...@@ -1758,7 +1781,7 @@ class MCMCSemiCoherentSearch(MCMCSearch):
maxStartTime, sftfilepattern=None, detectors=None, maxStartTime, sftfilepattern=None, detectors=None,
nsteps=[100, 100], nwalkers=100, ntemps=1, nsteps=[100, 100], nwalkers=100, ntemps=1,
log10temperature_min=-5, theta_initial=None, log10temperature_min=-5, theta_initial=None,
scatter_val=1e-10, rhohatmax=1000, binary=False, BSGL=False, rhohatmax=1000, binary=False, BSGL=False,
SSBprec=None, minCoverFreq=None, maxCoverFreq=None, SSBprec=None, minCoverFreq=None, maxCoverFreq=None,
injectSources=None, assumeSqrtSX=None, injectSources=None, assumeSqrtSX=None,
nsegs=None): nsegs=None):
...@@ -1792,7 +1815,7 @@ class MCMCSemiCoherentSearch(MCMCSearch): ...@@ -1792,7 +1815,7 @@ class MCMCSemiCoherentSearch(MCMCSearch):
def _get_data_dictionary_to_save(self): def _get_data_dictionary_to_save(self):
d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers, d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers,
ntemps=self.ntemps, theta_keys=self.theta_keys, ntemps=self.ntemps, theta_keys=self.theta_keys,
theta_prior=self.theta_prior, scatter_val=self.scatter_val, theta_prior=self.theta_prior,
log10temperature_min=self.log10temperature_min, log10temperature_min=self.log10temperature_min,
BSGL=self.BSGL, nsegs=self.nsegs) BSGL=self.BSGL, nsegs=self.nsegs)
return d return d
...@@ -1830,7 +1853,6 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch): ...@@ -1830,7 +1853,6 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
def _get_data_dictionary_to_save(self): def _get_data_dictionary_to_save(self):
d = dict(nwalkers=self.nwalkers, ntemps=self.ntemps, d = dict(nwalkers=self.nwalkers, ntemps=self.ntemps,
theta_keys=self.theta_keys, theta_prior=self.theta_prior, theta_keys=self.theta_keys, theta_prior=self.theta_prior,
scatter_val=self.scatter_val,
log10temperature_min=self.log10temperature_min, log10temperature_min=self.log10temperature_min,
BSGL=self.BSGL, run_setup=self.run_setup) BSGL=self.BSGL, run_setup=self.run_setup)
return d return d
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment