Commit 6090bd65 authored by Gregory Ashton

General improvements to the documentation and imports

parent e22e9b59
from __future__ import division
from __future__ import division as _division
from .core import BaseSearchClass, ComputeFstat, Writer
from .mcmc_based_searches import *
from .grid_based_searches import *
from .helper_functions import texify_float
from .core import BaseSearchClass, ComputeFstat, Writer, SemiCoherentSearch, SemiCoherentGlitchSearch
from .mcmc_based_searches import MCMCSearch, MCMCGlitchSearch, MCMCSemiCoherentSearch, MCMCFollowUpSearch, MCMCTransientSearch
from .grid_based_searches import GridSearch, GridUniformPriorSearch, GridGlitchSearch
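The switch away from star imports, together with aliasing the __future__ import to _division, controls exactly what the package namespace exposes: names bound by `from ... import` become module attributes, so an unaliased `division` would have appeared alongside the real public API. A minimal sketch of the resulting behaviour (assuming the package is installed as pyfstat):

import pyfstat

assert hasattr(pyfstat, 'MCMCSearch')    # explicitly re-exported above
assert not hasattr(pyfstat, 'division')  # hidden by the _division alias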
@@ -37,7 +37,7 @@ class BaseSearchClass(object):
earth_ephem_default = earth_ephem
sun_ephem_default = sun_ephem
def add_log_file(self):
def _add_log_file(self):
""" Log output to a file, requires class to have outdir and label """
logfilename = '{}/{}.log'.format(self.outdir, self.label)
fh = logging.FileHandler(logfilename)
@@ -47,7 +47,7 @@ class BaseSearchClass(object):
datefmt='%y-%m-%d %H:%M'))
logging.getLogger().addHandler(fh)
def shift_matrix(self, n, dT):
def _shift_matrix(self, n, dT):
""" Generate the shift matrix
Parameters
@@ -78,7 +78,7 @@ class BaseSearchClass(object):
m[i, j] = float(dT)**(j-i) / factorial(j-i)
return m
def shift_coefficients(self, theta, dT):
def _shift_coefficients(self, theta, dT):
""" Shift a set of coefficients by dT
Parameters
@@ -96,30 +96,30 @@ class BaseSearchClass(object):
"""
n = len(theta)
m = self.shift_matrix(n, dT)
m = self._shift_matrix(n, dT)
return np.dot(m, theta)
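The two helpers above implement a Taylor-expansion shift of the coefficients theta = [phi, F0, F1, ...] between reference times. A self-contained sketch of the same logic (standalone names, not the packaged code):

import numpy as np
from math import factorial

def shift_matrix(n, dT):
    # Upper-triangular matrix with m[i, j] = dT**(j-i) / (j-i)! for j >= i
    m = np.zeros((n, n))
    for i in range(n):
        for j in range(i, n):
            m[i, j] = float(dT)**(j - i) / factorial(j - i)
    return m

theta = np.array([0.0, 30.0, -1e-10, 0.0])  # [phi, F0, F1, F2] at tref
dT = 86400.0                                # move the reference time by a day
theta_shifted = np.dot(shift_matrix(len(theta), dT), theta)
# theta_shifted[1] equals F0 + F1*dT + F2*dT**2/2, the spun-down frequency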
def calculate_thetas(self, theta, delta_thetas, tbounds, theta0_idx=0):
def _calculate_thetas(self, theta, delta_thetas, tbounds, theta0_idx=0):
""" Calculates the set of coefficients for the post-glitch signal """
thetas = [theta]
for i, dt in enumerate(delta_thetas):
if i < theta0_idx:
pre_theta_at_ith_glitch = self.shift_coefficients(
pre_theta_at_ith_glitch = self._shift_coefficients(
thetas[0], tbounds[i+1] - self.tref)
post_theta_at_ith_glitch = pre_theta_at_ith_glitch - dt
thetas.insert(0, self.shift_coefficients(
thetas.insert(0, self._shift_coefficients(
post_theta_at_ith_glitch, self.tref - tbounds[i+1]))
elif i >= theta0_idx:
pre_theta_at_ith_glitch = self.shift_coefficients(
pre_theta_at_ith_glitch = self._shift_coefficients(
thetas[i], tbounds[i+1] - self.tref)
post_theta_at_ith_glitch = pre_theta_at_ith_glitch + dt
thetas.append(self.shift_coefficients(
thetas.append(self._shift_coefficients(
post_theta_at_ith_glitch, self.tref - tbounds[i+1]))
self.thetas_at_tref = thetas
return thetas
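_calculate_thetas applies that shift across each glitch: propagate the coefficients to the glitch time, add the jump (subtract it for glitches before theta0_idx), then propagate back so every segment is referenced to tref. Continuing the sketch above (reusing np, theta and shift_matrix from it):

tref, tglitch = 900000000.0, 900086400.0       # illustrative GPS times
delta_theta = np.array([0.0, 1e-6, 0.0, 0.0])  # e.g. a step in F0 only

theta_at_glitch = np.dot(shift_matrix(4, tglitch - tref), theta)
post_theta_at_glitch = theta_at_glitch + delta_theta
theta_post = np.dot(shift_matrix(4, tref - tglitch), post_theta_at_glitch)
# theta_post describes the post-glitch segment, referenced back to tref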
def generate_loudest(self):
def _generate_loudest(self):
params = read_par(self.label, self.outdir)
for key in ['Alpha', 'Delta', 'F0', 'F1']:
if key not in params:
@@ -133,7 +133,7 @@ class BaseSearchClass(object):
self.maxStartTime)
subprocess.call([cmd], shell=True)
def get_list_of_matching_sfts(self):
def _get_list_of_matching_sfts(self):
matches = [glob.glob(p) for p in self.sftfilepath]
matches = [item for sublist in matches for item in sublist]
if len(matches) > 0:
@@ -685,7 +685,7 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat):
delta_thetas = np.atleast_2d(
np.array([delta_phi, delta_F0s, delta_F1s, delta_F2]).T)
thetas = self.calculate_thetas(theta, delta_thetas, tboundaries,
thetas = self._calculate_thetas(theta, delta_thetas, tboundaries,
theta0_idx=self.theta0_idx)
twoFSum = 0
@@ -713,9 +713,9 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat):
delta_theta = [delta_F0, delta_F1, 0]
tref = self.tref
theta_at_glitch = self.shift_coefficients(theta, tglitch - tref)
theta_at_glitch = self._shift_coefficients(theta, tglitch - tref)
theta_post_glitch_at_glitch = theta_at_glitch + delta_theta
theta_post_glitch = self.shift_coefficients(
theta_post_glitch = self._shift_coefficients(
theta_post_glitch_at_glitch, tref - tglitch)
twoFsegA = self.run_computefstatistic_single_point(
@@ -849,7 +849,7 @@ transientTauDays={:1.3f}\n""")
"""
thetas = self.calculate_thetas(self.theta, self.delta_thetas,
thetas = self._calculate_thetas(self.theta, self.delta_thetas,
self.tbounds)
content = ''
......
@@ -65,6 +65,7 @@ def set_up_command_line_arguments():
def set_up_ephemeris_configuration():
""" Returns the earth_ephem and sun_ephem """
config_file = os.path.expanduser('~')+'/.pyfstat.conf'
if os.path.isfile(config_file):
d = {}
......
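The parsing of the config file is elided from this hunk, so the following is only an illustrative guess at a ~/.pyfstat.conf, with the key syntax assumed and the paths pointing at standard lalpulsar ephemeris files:

# Hypothetical ~/.pyfstat.conf (key names and syntax are assumptions)
earth_ephem = '/usr/share/lalpulsar/earth00-19-DE421.dat.gz'
sun_ephem = '/usr/share/lalpulsar/sun00-19-DE421.dat.gz'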
@@ -44,7 +44,8 @@ class MCMCSearch(core.BaseSearchClass):
label, outdir: str
A label and directory to read/write data from/to
sftfilepath: str
File patern to match SFTs
Pattern to match SFTs using wildcards (*?) and ranges [0-9];
multiple patterns can be given separated by colons.
theta_prior: dict
Dictionary of priors and fixed values for the search parameters.
For each parameter (key of the dict), if it is to be held fixed
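To illustrate how the documented parameters fit together, a hedged construction sketch; the prior values are placeholders, and the tref/nsteps/nwalkers/ntemps keywords are assumptions inferred from attributes used later in the class:

from pyfstat import MCMCSearch

theta_prior = {
    'F0': {'type': 'unif', 'lower': 29.9, 'upper': 30.1},  # searched over
    'F1': -1e-10,  # held fixed at this value
    'F2': 0,
    'Alpha': 1.0,
    'Delta': 0.5,
}
search = MCMCSearch(
    label='example', outdir='example_out',
    sftfilepath='data/*H1*.sft:data/*L1*.sft',  # colon-separated patterns
    theta_prior=theta_prior, tref=1126259462,
    nsteps=[500, 500], nwalkers=100, ntemps=1)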
@@ -84,12 +85,12 @@ class MCMCSearch(core.BaseSearchClass):
if os.path.isdir(outdir) is False:
os.mkdir(outdir)
self.add_log_file()
self._add_log_file()
logging.info(
'Set-up MCMC search for model {} on data {}'.format(
self.label, self.sftfilepath))
self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
self.unpack_input_theta()
self._unpack_input_theta()
self.ndim = len(self.theta_keys)
if self.log10temperature_min:
self.betas = np.logspace(0, self.log10temperature_min, self.ntemps)
@@ -104,9 +105,9 @@ class MCMCSearch(core.BaseSearchClass):
if args.clean and os.path.isfile(self.pickle_path):
os.rename(self.pickle_path, self.pickle_path+".old")
self.log_input()
self._log_input()
def log_input(self):
def _log_input(self):
logging.info('theta_prior = {}'.format(self.theta_prior))
logging.info('nwalkers={}'.format(self.nwalkers))
logging.info('scatter_val = {}'.format(self.scatter_val))
@@ -115,7 +116,7 @@ class MCMCSearch(core.BaseSearchClass):
logging.info('log10temperature_min = {}'.format(
self.log10temperature_min))
def initiate_search_object(self):
def _initiate_search_object(self):
logging.info('Setting up search object')
self.search = core.ComputeFstat(
tref=self.tref, sftfilepath=self.sftfilepath,
@@ -127,7 +128,7 @@ class MCMCSearch(core.BaseSearchClass):
assumeSqrtSX=self.assumeSqrtSX)
def logp(self, theta_vals, theta_prior, theta_keys, search):
H = [self.generic_lnprior(**theta_prior[key])(p) for p, key in
H = [self._generic_lnprior(**theta_prior[key])(p) for p, key in
zip(theta_vals, theta_keys)]
return np.sum(H)
@@ -138,7 +139,7 @@ class MCMCSearch(core.BaseSearchClass):
*self.fixed_theta)
return FS
def unpack_input_theta(self):
def _unpack_input_theta(self):
full_theta_keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta']
if self.binary:
full_theta_keys += [
@@ -179,7 +180,7 @@ class MCMCSearch(core.BaseSearchClass):
self.theta_symbols = [self.theta_symbols[i] for i in idxs]
self.theta_keys = [self.theta_keys[i] for i in idxs]
def check_initial_points(self, p0):
def _check_initial_points(self, p0):
for nt in range(self.ntemps):
logging.info('Checking temperature {} chains'.format(nt))
initial_priors = np.array([
@@ -193,10 +194,10 @@ class MCMCSearch(core.BaseSearchClass):
.format(len(initial_priors),
number_of_initial_out_of_bounds))
p0 = self.generate_new_p0_to_fix_initial_points(
p0 = self._generate_new_p0_to_fix_initial_points(
p0, nt, initial_priors)
def generate_new_p0_to_fix_initial_points(self, p0, nt, initial_priors):
def _generate_new_p0_to_fix_initial_points(self, p0, nt, initial_priors):
logging.info('Attempting to correct initial values')
idxs = np.arange(self.nwalkers)[initial_priors == -np.inf]
count = 0
@@ -217,7 +218,7 @@ class MCMCSearch(core.BaseSearchClass):
return p0
def OLD_run_sampler_with_progress_bar(self, sampler, ns, p0):
def _OLD_run_sampler_with_progress_bar(self, sampler, ns, p0):
for result in tqdm(sampler.sample(p0, iterations=ns), total=ns):
pass
return sampler
@@ -271,7 +272,7 @@ class MCMCSearch(core.BaseSearchClass):
self.convergence_number = 0
self.convergence_plot_upper_lim = convergence_plot_upper_lim
def get_convergence_statistic(self, i, sampler):
def _get_convergence_statistic(self, i, sampler):
s = sampler.chain[0, :, i-self.convergence_length+1:i+1, :]
within_std = np.mean(np.var(s, axis=1), axis=0)
per_walker_mean = np.mean(s, axis=1)
@@ -286,25 +287,25 @@ class MCMCSearch(core.BaseSearchClass):
self.convergence_diagnosticx.append(i - self.convergence_length/2)
return c
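The quantities above are the ingredients of a Gelman-Rubin-style diagnostic comparing within-walker and between-walker variance; the rest of the calculation is cut from this hunk, so the completion below is an assumption loosely following the standard definition, not the actual source:

import numpy as np

def convergence_statistic(s):
    # s: chain segment of shape (nwalkers, nsteps, ndim), one temperature
    within_std = np.mean(np.var(s, axis=1), axis=0)  # mean within-walker variance
    per_walker_mean = np.mean(s, axis=1)             # shape (nwalkers, ndim)
    between_std = np.var(per_walker_mean, axis=0)    # variance of walker means
    # values near 1 indicate the walkers have mixed
    return np.sqrt((within_std + between_std) / within_std)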
def burnin_convergence_test(self, i, sampler, nburn):
def _burnin_convergence_test(self, i, sampler, nburn):
if i < self.convergence_burnin_fraction*nburn:
return False
if np.mod(i+1, self.convergence_period) != 0:
return False
c = self.get_convergence_statistic(i, sampler)
c = self._get_convergence_statistic(i, sampler)
if np.all(c < self.convergence_threshold):
self.convergence_number += 1
else:
self.convergence_number = 0
return self.convergence_number > self.convergence_threshold_number
def prod_convergence_test(self, i, sampler, nburn):
def _prod_convergence_test(self, i, sampler, nburn):
testA = i > nburn + self.convergence_length
testB = np.mod(i+1, self.convergence_period) == 0
if testA and testB:
self.get_convergence_statistic(i, sampler)
self._get_convergence_statistic(i, sampler)
def check_production_convergence(self, k):
def _check_production_convergence(self, k):
bools = np.any(
np.array(self.convergence_diagnostic)[k:, :]
> self.convergence_prod_threshold, axis=1)
@@ -313,13 +314,13 @@ class MCMCSearch(core.BaseSearchClass):
'{} convergence tests in the production run of {} failed'
.format(np.sum(bools), len(bools)))
def run_sampler(self, sampler, p0, nprod=0, nburn=0):
def _run_sampler(self, sampler, p0, nprod=0, nburn=0):
if hasattr(self, 'convergence_period'):
logging.info('Running {} burn-in steps with convergence testing'
.format(nburn))
iterator = tqdm(sampler.sample(p0, iterations=nburn), total=nburn)
for i, output in enumerate(iterator):
if self.burnin_convergence_test(i, sampler, nburn):
if self._burnin_convergence_test(i, sampler, nburn):
logging.info(
'Converged at {} before max number {} of steps reached'
.format(i, nburn))
@@ -331,9 +332,9 @@ class MCMCSearch(core.BaseSearchClass):
k = len(self.convergence_diagnostic)
for result in tqdm(sampler.sample(output[0], iterations=nprod),
total=nprod):
self.prod_convergence_test(j, sampler, nburn)
self._prod_convergence_test(j, sampler, nburn)
j += 1
self.check_production_convergence(k)
self._check_production_convergence(k)
return sampler
else:
for result in tqdm(sampler.sample(p0, iterations=nburn+nprod),
@@ -342,50 +343,51 @@ class MCMCSearch(core.BaseSearchClass):
return sampler
def run(self, proposal_scale_factor=2, create_plots=True, **kwargs):
""" Run the MCMC simulatation """
self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
self.old_data_is_okay_to_use = self._check_old_data_is_okay_to_use()
if self.old_data_is_okay_to_use is True:
logging.warning('Using saved data from {}'.format(
self.pickle_path))
d = self.get_saved_data()
d = self.get_saved_data_dictionary()
self.sampler = d['sampler']
self.samples = d['samples']
self.lnprobs = d['lnprobs']
self.lnlikes = d['lnlikes']
return
self.initiate_search_object()
self._initiate_search_object()
sampler = emcee.PTSampler(
self.ntemps, self.nwalkers, self.ndim, self.logl, self.logp,
logpargs=(self.theta_prior, self.theta_keys, self.search),
loglargs=(self.search,), betas=self.betas, a=proposal_scale_factor)
p0 = self.generate_initial_p0()
p0 = self.apply_corrections_to_p0(p0)
self.check_initial_points(p0)
p0 = self._generate_initial_p0()
p0 = self._apply_corrections_to_p0(p0)
self._check_initial_points(p0)
ninit_steps = len(self.nsteps) - 2
for j, n in enumerate(self.nsteps[:-2]):
logging.info('Running {}/{} initialisation with {} steps'.format(
j, ninit_steps, n))
sampler = self.run_sampler(sampler, p0, nburn=n)
sampler = self._run_sampler(sampler, p0, nburn=n)
logging.info("Mean acceptance fraction: {}"
.format(np.mean(sampler.acceptance_fraction, axis=1)))
if self.ntemps > 1:
logging.info("Tswap acceptance fraction: {}"
.format(sampler.tswap_acceptance_fraction))
if create_plots:
fig, axes = self.plot_walkers(sampler,
fig, axes = self._plot_walkers(sampler,
symbols=self.theta_symbols,
**kwargs)
fig.tight_layout()
fig.savefig('{}/{}_init_{}_walkers.png'.format(
self.outdir, self.label, j), dpi=400)
p0 = self.get_new_p0(sampler)
p0 = self.apply_corrections_to_p0(p0)
self.check_initial_points(p0)
p0 = self._get_new_p0(sampler)
p0 = self._apply_corrections_to_p0(p0)
self._check_initial_points(p0)
sampler.reset()
if len(self.nsteps) > 1:
@@ -395,7 +397,7 @@ class MCMCSearch(core.BaseSearchClass):
nprod = self.nsteps[-1]
logging.info('Running final burn and prod with {} steps'.format(
nburn+nprod))
sampler = self.run_sampler(sampler, p0, nburn=nburn, nprod=nprod)
sampler = self._run_sampler(sampler, p0, nburn=nburn, nprod=nprod)
logging.info("Mean acceptance fraction: {}"
.format(np.mean(sampler.acceptance_fraction, axis=1)))
if self.ntemps > 1:
@@ -403,7 +405,7 @@ class MCMCSearch(core.BaseSearchClass):
.format(sampler.tswap_acceptance_fraction))
if create_plots:
fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols,
fig, axes = self._plot_walkers(sampler, symbols=self.theta_symbols,
nprod=nprod, **kwargs)
fig.tight_layout()
fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label),
@@ -416,9 +418,9 @@ class MCMCSearch(core.BaseSearchClass):
self.samples = samples
self.lnprobs = lnprobs
self.lnlikes = lnlikes
self.save_data(sampler, samples, lnprobs, lnlikes)
self._save_data(sampler, samples, lnprobs, lnlikes)
def get_rescale_multiplier_for_key(self, key):
def _get_rescale_multiplier_for_key(self, key):
""" Get the rescale multiplier from the rescale_dictionary
Can either be a float, a string (in which case it is interpreted as
@@ -443,7 +445,7 @@ class MCMCSearch(core.BaseSearchClass):
multiplier = 1
return multiplier
def get_rescale_subtractor_for_key(self, key):
def _get_rescale_subtractor_for_key(self, key):
""" Get the rescale subtractor from the rescale_dictionary
Can either be a float, a string (in which case it is interpreted as
@@ -468,21 +470,21 @@ class MCMCSearch(core.BaseSearchClass):
subtractor = 0
return subtractor
def scale_samples(self, samples, theta_keys):
def _scale_samples(self, samples, theta_keys):
""" Scale the samples using the rescale_dictionary """
for key in theta_keys:
if key in self.rescale_dictionary:
idx = theta_keys.index(key)
s = samples[:, idx]
subtractor = self.get_rescale_subtractor_for_key(key)
subtractor = self._get_rescale_subtractor_for_key(key)
s = s - subtractor
multiplier = self.get_rescale_multiplier_for_key(key)
multiplier = self._get_rescale_multiplier_for_key(key)
s *= multiplier
samples[:, idx] = s
return samples
def get_labels(self):
def _get_labels(self):
""" Combine the units, symbols and rescaling to give labels """
labels = []
@@ -503,9 +505,38 @@ class MCMCSearch(core.BaseSearchClass):
labels.append(label)
return labels
def plot_corner(self, figsize=(7, 7), tglitch_ratio=False,
add_prior=False, nstds=None, label_offset=0.4,
dpi=300, rc_context={}, **kwargs):
def plot_corner(self, figsize=(7, 7), add_prior=False, nstds=None,
label_offset=0.4, dpi=300, rc_context={},
tglitch_ratio=False, **kwargs):
""" Generate a corner plot of the posterior
Using the `corner` package (https://pypi.python.org/pypi/corner/),
generate estimates of the posterior from the production samples.
Parameters
----------
figsize: tuple (7, 7)
Figure size in inches (passed to plt.subplots)
add_prior: bool
If true, plot the prior as a red line
nstds: float
The number of standard deviations to plot centered on the mean
label_offset: float
Offset the labels from the plot: useful to prevent overlapping the
tick labels with the axis labels
dpi: int
Passed to plt.savefig
rc_context: dict
Dictionary of rc values to set while generating the figure (see
matplotlib rc for more details)
tglitch_ratio: bool
If true, and tglitch is a parameter, plot posteriors as the
fractional time at which the glitch occurs instead of the actual
time
Note: kwargs are passed on to corner.corner
"""
if self.ndim < 2:
with plt.rc_context(rc_context):
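# A hedged usage sketch for the options documented in the docstring above
# (argument values illustrative; `search` is an MCMCSearch instance such
# as the one sketched earlier):
# search.plot_corner(add_prior=True, nstds=5, rc_context={'font.size': 8})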
@@ -522,9 +553,9 @@ class MCMCSearch(core.BaseSearchClass):
figsize=figsize)
samples_plt = copy.copy(self.samples)
labels = self.get_labels()
labels = self._get_labels()
samples_plt = self.scale_samples(samples_plt, self.theta_keys)
samples_plt = self._scale_samples(samples_plt, self.theta_keys)
if tglitch_ratio:
for j, k in enumerate(self.theta_keys):
@@ -571,20 +602,20 @@ class MCMCSearch(core.BaseSearchClass):
fig.subplots_adjust(hspace=0.05, wspace=0.05)
if add_prior:
self.add_prior_to_corner(axes, self.samples)
self._add_prior_to_corner(axes, self.samples)
fig_triangle.savefig('{}/{}_corner.png'.format(
self.outdir, self.label), dpi=dpi)
def add_prior_to_corner(self, axes, samples):
def _add_prior_to_corner(self, axes, samples):
for i, key in enumerate(self.theta_keys):
ax = axes[i][i]
xlim = ax.get_xlim()
s = samples[:, i]
prior = self.generic_lnprior(**self.theta_prior[key])
prior = self._generic_lnprior(**self.theta_prior[key])
x = np.linspace(s.min(), s.max(), 100)
multiplier = self.get_rescale_multiplier_for_key(key)
subtractor = self.get_rescale_subtractor_for_key(key)
multiplier = self._get_rescale_multiplier_for_key(key)
subtractor = self._get_rescale_subtractor_for_key(key)
ax2 = ax.twinx()
ax2.get_yaxis().set_visible(False)
ax2.plot((x-subtractor)*multiplier, [prior(xi) for xi in x], '-r')
@@ -598,7 +629,7 @@ class MCMCSearch(core.BaseSearchClass):
for i, (ax, key) in enumerate(zip(axes, self.theta_keys)):
prior_dict = self.theta_prior[key]
prior_func = self.generic_lnprior(**prior_dict)
prior_func = self._generic_lnprior(**prior_dict)
if prior_dict['type'] == 'unif':
x = np.linspace(prior_dict['lower'], prior_dict['upper'], N)
prior = prior_func(x)
@@ -643,13 +674,17 @@ class MCMCSearch(core.BaseSearchClass):
self.outdir, self.label))
def plot_cumulative_max(self, **kwargs):
""" Plot the cumulative twoF for the maximum posterior estimate
See the pyfstat.core.plot_twoF_cumulative function for further details
"""
d, maxtwoF = self.get_max_twoF()
for key, val in self.theta_prior.iteritems():
if key not in d:
d[key] = val
if hasattr(self, 'search') is False:
self.initiate_search_object()
self._initiate_search_object()
if self.binary is False:
self.search.plot_twoF_cumulative(
self.label, self.outdir, F0=d['F0'], F1=d['F1'], F2=d['F2'],
@@ -663,7 +698,7 @@ class MCMCSearch(core.BaseSearchClass):
period=d['period'], ecc=d['ecc'], argp=d['argp'], tp=d['tp'],
tstart=self.minStartTime, tend=self.maxStartTime, **kwargs)
def generic_lnprior(self, **kwargs):
def _generic_lnprior(self, **kwargs):
""" Return a lambda function of the pdf
Parameters
@@ -715,7 +750,7 @@ class MCMCSearch(core.BaseSearchClass):
logging.info("kwargs:", kwargs)
raise ValueError("Print unrecognise distribution")
def generate_rv(self, **kwargs):
def _generate_rv(self, **kwargs):
dist_type = kwargs.pop('type')
if dist_type == "unif":
return np.random.uniform(low=kwargs['lower'], high=kwargs['upper'])
@@ -733,10 +768,10 @@ class MCMCSearch(core.BaseSearchClass):
else:
raise ValueError("dist_type {} unknown".format(dist_type))
def plot_walkers(self, sampler, symbols=None, alpha=0.4, color="k", temp=0,
lw=0.1, nprod=0, add_det_stat_burnin=False,
fig=None, axes=None, xoffset=0, plot_det_stat=False,
context='classic', subtractions=None, labelpad=0.05):
def _plot_walkers(self, sampler, symbols=None, alpha=0.4, color="k",
temp=0, lw=0.1, nprod=0, add_det_stat_burnin=False,
fig=None, axes=None, xoffset=0, plot_det_stat=False,
context='classic', subtractions=None, labelpad=0.05):
""" Plot all the chains from a sampler """
if np.ndim(axes) > 1:
@@ -865,48 +900,48 @@ class MCMCSearch(core.BaseSearchClass):
axes[-2].set_xlabel(r'$\textrm{Number of steps}$', labelpad=0.2)
return fig, axes
def apply_corrections_to_p0(self, p0):
def _apply_corrections_to_p0(self, p0):
""" Apply any correction to the initial p0 values """
return p0
def generate_scattered_p0(self, p):
def _generate_scattered_p0(self, p):
""" Generate a set of p0s scattered about p """
p0 = [[p + self.scatter_val * p * np.random.randn(self.ndim)
for i in xrange(self.nwalkers)]
for j in xrange(self.ntemps)]
return p0
def generate_initial_p0(self):
def _generate_initial_p0(self):
""" Generate a set of init vals for the walkers """
if type(self.theta_initial) == dict:
logging.info('Generate initial values from initial dictionary')
if hasattr(self, 'nglitch') and self.nglitch > 1:
raise ValueError('Initial dict not implemented for nglitch>1')
p0 = [[[self.generate_rv(**self.theta_initial[key])
p0 = [[[self._generate_rv(**self.theta_initial[key])
for key in self.theta_keys]
for i in range(self.nwalkers)]
for j in range(self.ntemps)]
elif type(self.theta_initial) == list:
logging.info('Generate initial values from list of theta_initial')
p0 = [[[self.generate_rv(**val)
p0 = [[[self._generate_rv(**val)
for val in self.theta_initial]
for i in range(self.nwalkers)]
for j in range(self.ntemps)]
elif self.theta_initial is None:
logging.info('Generate initial values from prior dictionary')
p0 = [[[self.generate_rv(**self.theta_prior[key])
p0 = [[[self._generate_rv(**self.theta_prior[key])
for key in self.theta_keys]
for i in range(self.nwalkers)]
for j in range(self.ntemps)]
elif len(self.theta_initial) == self.ndim:
p0 = self.generate_scattered_p0(self.theta_initial)
p0 = self._generate_scattered_p0(self.theta_initial)
else:
raise ValueError('theta_initial not understood')
return p0
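Per the branches above, theta_initial accepts four forms; a compact summary with illustrative values:

# 1. dict keyed like theta_prior (only supported for nglitch <= 1)
theta_initial = {'F0': {'type': 'unif', 'lower': 29.99, 'upper': 30.01}}
# 2. list of one distribution spec per searched dimension
theta_initial = [{'type': 'unif', 'lower': 29.99, 'upper': 30.01}]
# 3. None: draw the initial values from theta_prior
theta_initial = None
# 4. a point of length ndim, scattered about using scatter_val
theta_initial = [30.0]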
def get_new_p0(self, sampler):
def _get_new_p0(self, sampler):
""" Returns new initial positions for walkers are burn0 stage
This returns new positions for all walkers by scattering points about
@@ -936,7 +971,7 @@ class MCMCSearch(core.BaseSearchClass):
lnp_finite[np.isinf(lnp)] = np.nan
idx = np.unravel_index(np.nanargmax(lnp_finite), lnp_finite.shape)
p = pF[idx]
p0 = self.generate_scattered_p0(p)
p0 = self._generate_scattered_p0(p)
self.search.BSGL = False
twoF = self.logl(p, self.search)
@@ -948,7 +983,7 @@ class MCMCSearch(core.BaseSearchClass):