Commit 6090bd65 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

General improvements to the documentation and imports

parent e22e9b59
from __future__ import division from __future__ import division as _division
from .core import BaseSearchClass, ComputeFstat, Writer from .core import BaseSearchClass, ComputeFstat, Writer, SemiCoherentSearch, SemiCoherentGlitchSearch
from .mcmc_based_searches import * from .mcmc_based_searches import MCMCSearch, MCMCGlitchSearch, MCMCSemiCoherentSearch, MCMCFollowUpSearch, MCMCTransientSearch
from .grid_based_searches import * from .grid_based_searches import GridSearch, GridUniformPriorSearch, GridGlitchSearch
from .helper_functions import texify_float
...@@ -37,7 +37,7 @@ class BaseSearchClass(object): ...@@ -37,7 +37,7 @@ class BaseSearchClass(object):
earth_ephem_default = earth_ephem earth_ephem_default = earth_ephem
sun_ephem_default = sun_ephem sun_ephem_default = sun_ephem
def add_log_file(self): def _add_log_file(self):
""" Log output to a file, requires class to have outdir and label """ """ Log output to a file, requires class to have outdir and label """
logfilename = '{}/{}.log'.format(self.outdir, self.label) logfilename = '{}/{}.log'.format(self.outdir, self.label)
fh = logging.FileHandler(logfilename) fh = logging.FileHandler(logfilename)
...@@ -47,7 +47,7 @@ class BaseSearchClass(object): ...@@ -47,7 +47,7 @@ class BaseSearchClass(object):
datefmt='%y-%m-%d %H:%M')) datefmt='%y-%m-%d %H:%M'))
logging.getLogger().addHandler(fh) logging.getLogger().addHandler(fh)
def shift_matrix(self, n, dT): def _shift_matrix(self, n, dT):
""" Generate the shift matrix """ Generate the shift matrix
Parameters Parameters
...@@ -78,7 +78,7 @@ class BaseSearchClass(object): ...@@ -78,7 +78,7 @@ class BaseSearchClass(object):
m[i, j] = float(dT)**(j-i) / factorial(j-i) m[i, j] = float(dT)**(j-i) / factorial(j-i)
return m return m
def shift_coefficients(self, theta, dT): def _shift_coefficients(self, theta, dT):
""" Shift a set of coefficients by dT """ Shift a set of coefficients by dT
Parameters Parameters
...@@ -96,30 +96,30 @@ class BaseSearchClass(object): ...@@ -96,30 +96,30 @@ class BaseSearchClass(object):
""" """
n = len(theta) n = len(theta)
m = self.shift_matrix(n, dT) m = self._shift_matrix(n, dT)
return np.dot(m, theta) return np.dot(m, theta)
def calculate_thetas(self, theta, delta_thetas, tbounds, theta0_idx=0): def _calculate_thetas(self, theta, delta_thetas, tbounds, theta0_idx=0):
""" Calculates the set of coefficients for the post-glitch signal """ """ Calculates the set of coefficients for the post-glitch signal """
thetas = [theta] thetas = [theta]
for i, dt in enumerate(delta_thetas): for i, dt in enumerate(delta_thetas):
if i < theta0_idx: if i < theta0_idx:
pre_theta_at_ith_glitch = self.shift_coefficients( pre_theta_at_ith_glitch = self._shift_coefficients(
thetas[0], tbounds[i+1] - self.tref) thetas[0], tbounds[i+1] - self.tref)
post_theta_at_ith_glitch = pre_theta_at_ith_glitch - dt post_theta_at_ith_glitch = pre_theta_at_ith_glitch - dt
thetas.insert(0, self.shift_coefficients( thetas.insert(0, self._shift_coefficients(
post_theta_at_ith_glitch, self.tref - tbounds[i+1])) post_theta_at_ith_glitch, self.tref - tbounds[i+1]))
elif i >= theta0_idx: elif i >= theta0_idx:
pre_theta_at_ith_glitch = self.shift_coefficients( pre_theta_at_ith_glitch = self._shift_coefficients(
thetas[i], tbounds[i+1] - self.tref) thetas[i], tbounds[i+1] - self.tref)
post_theta_at_ith_glitch = pre_theta_at_ith_glitch + dt post_theta_at_ith_glitch = pre_theta_at_ith_glitch + dt
thetas.append(self.shift_coefficients( thetas.append(self._shift_coefficients(
post_theta_at_ith_glitch, self.tref - tbounds[i+1])) post_theta_at_ith_glitch, self.tref - tbounds[i+1]))
self.thetas_at_tref = thetas self.thetas_at_tref = thetas
return thetas return thetas
def generate_loudest(self): def _generate_loudest(self):
params = read_par(self.label, self.outdir) params = read_par(self.label, self.outdir)
for key in ['Alpha', 'Delta', 'F0', 'F1']: for key in ['Alpha', 'Delta', 'F0', 'F1']:
if key not in params: if key not in params:
...@@ -133,7 +133,7 @@ class BaseSearchClass(object): ...@@ -133,7 +133,7 @@ class BaseSearchClass(object):
self.maxStartTime) self.maxStartTime)
subprocess.call([cmd], shell=True) subprocess.call([cmd], shell=True)
def get_list_of_matching_sfts(self): def _get_list_of_matching_sfts(self):
matches = [glob.glob(p) for p in self.sftfilepath] matches = [glob.glob(p) for p in self.sftfilepath]
matches = [item for sublist in matches for item in sublist] matches = [item for sublist in matches for item in sublist]
if len(matches) > 0: if len(matches) > 0:
...@@ -685,7 +685,7 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat): ...@@ -685,7 +685,7 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat):
delta_thetas = np.atleast_2d( delta_thetas = np.atleast_2d(
np.array([delta_phi, delta_F0s, delta_F1s, delta_F2]).T) np.array([delta_phi, delta_F0s, delta_F1s, delta_F2]).T)
thetas = self.calculate_thetas(theta, delta_thetas, tboundaries, thetas = self._calculate_thetas(theta, delta_thetas, tboundaries,
theta0_idx=self.theta0_idx) theta0_idx=self.theta0_idx)
twoFSum = 0 twoFSum = 0
...@@ -713,9 +713,9 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat): ...@@ -713,9 +713,9 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat):
delta_theta = [delta_F0, delta_F1, 0] delta_theta = [delta_F0, delta_F1, 0]
tref = self.tref tref = self.tref
theta_at_glitch = self.shift_coefficients(theta, tglitch - tref) theta_at_glitch = self._shift_coefficients(theta, tglitch - tref)
theta_post_glitch_at_glitch = theta_at_glitch + delta_theta theta_post_glitch_at_glitch = theta_at_glitch + delta_theta
theta_post_glitch = self.shift_coefficients( theta_post_glitch = self._shift_coefficients(
theta_post_glitch_at_glitch, tref - tglitch) theta_post_glitch_at_glitch, tref - tglitch)
twoFsegA = self.run_computefstatistic_single_point( twoFsegA = self.run_computefstatistic_single_point(
...@@ -849,7 +849,7 @@ transientTauDays={:1.3f}\n""") ...@@ -849,7 +849,7 @@ transientTauDays={:1.3f}\n""")
""" """
thetas = self.calculate_thetas(self.theta, self.delta_thetas, thetas = self._calculate_thetas(self.theta, self.delta_thetas,
self.tbounds) self.tbounds)
content = '' content = ''
......
...@@ -65,6 +65,7 @@ def set_up_command_line_arguments(): ...@@ -65,6 +65,7 @@ def set_up_command_line_arguments():
def set_up_ephemeris_configuration(): def set_up_ephemeris_configuration():
""" Returns the earth_ephem and sun_ephem """
config_file = os.path.expanduser('~')+'/.pyfstat.conf' config_file = os.path.expanduser('~')+'/.pyfstat.conf'
if os.path.isfile(config_file): if os.path.isfile(config_file):
d = {} d = {}
......
...@@ -44,7 +44,8 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -44,7 +44,8 @@ class MCMCSearch(core.BaseSearchClass):
label, outdir: str label, outdir: str
A label and directory to read/write data from/to A label and directory to read/write data from/to
sftfilepath: str sftfilepath: str
File patern to match SFTs Pattern to match SFTs using wildcards (*?) and ranges [0-9];
mutiple patterns can be given separated by colons.
theta_prior: dict theta_prior: dict
Dictionary of priors and fixed values for the search parameters. Dictionary of priors and fixed values for the search parameters.
For each parameters (key of the dict), if it is to be held fixed For each parameters (key of the dict), if it is to be held fixed
...@@ -84,12 +85,12 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -84,12 +85,12 @@ class MCMCSearch(core.BaseSearchClass):
if os.path.isdir(outdir) is False: if os.path.isdir(outdir) is False:
os.mkdir(outdir) os.mkdir(outdir)
self.add_log_file() self._add_log_file()
logging.info( logging.info(
'Set-up MCMC search for model {} on data {}'.format( 'Set-up MCMC search for model {} on data {}'.format(
self.label, self.sftfilepath)) self.label, self.sftfilepath))
self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label) self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
self.unpack_input_theta() self._unpack_input_theta()
self.ndim = len(self.theta_keys) self.ndim = len(self.theta_keys)
if self.log10temperature_min: if self.log10temperature_min:
self.betas = np.logspace(0, self.log10temperature_min, self.ntemps) self.betas = np.logspace(0, self.log10temperature_min, self.ntemps)
...@@ -104,9 +105,9 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -104,9 +105,9 @@ class MCMCSearch(core.BaseSearchClass):
if args.clean and os.path.isfile(self.pickle_path): if args.clean and os.path.isfile(self.pickle_path):
os.rename(self.pickle_path, self.pickle_path+".old") os.rename(self.pickle_path, self.pickle_path+".old")
self.log_input() self._log_input()
def log_input(self): def _log_input(self):
logging.info('theta_prior = {}'.format(self.theta_prior)) logging.info('theta_prior = {}'.format(self.theta_prior))
logging.info('nwalkers={}'.format(self.nwalkers)) logging.info('nwalkers={}'.format(self.nwalkers))
logging.info('scatter_val = {}'.format(self.scatter_val)) logging.info('scatter_val = {}'.format(self.scatter_val))
...@@ -115,7 +116,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -115,7 +116,7 @@ class MCMCSearch(core.BaseSearchClass):
logging.info('log10temperature_min = {}'.format( logging.info('log10temperature_min = {}'.format(
self.log10temperature_min)) self.log10temperature_min))
def initiate_search_object(self): def _initiate_search_object(self):
logging.info('Setting up search object') logging.info('Setting up search object')
self.search = core.ComputeFstat( self.search = core.ComputeFstat(
tref=self.tref, sftfilepath=self.sftfilepath, tref=self.tref, sftfilepath=self.sftfilepath,
...@@ -127,7 +128,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -127,7 +128,7 @@ class MCMCSearch(core.BaseSearchClass):
assumeSqrtSX=self.assumeSqrtSX) assumeSqrtSX=self.assumeSqrtSX)
def logp(self, theta_vals, theta_prior, theta_keys, search): def logp(self, theta_vals, theta_prior, theta_keys, search):
H = [self.generic_lnprior(**theta_prior[key])(p) for p, key in H = [self._generic_lnprior(**theta_prior[key])(p) for p, key in
zip(theta_vals, theta_keys)] zip(theta_vals, theta_keys)]
return np.sum(H) return np.sum(H)
...@@ -138,7 +139,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -138,7 +139,7 @@ class MCMCSearch(core.BaseSearchClass):
*self.fixed_theta) *self.fixed_theta)
return FS return FS
def unpack_input_theta(self): def _unpack_input_theta(self):
full_theta_keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta'] full_theta_keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta']
if self.binary: if self.binary:
full_theta_keys += [ full_theta_keys += [
...@@ -179,7 +180,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -179,7 +180,7 @@ class MCMCSearch(core.BaseSearchClass):
self.theta_symbols = [self.theta_symbols[i] for i in idxs] self.theta_symbols = [self.theta_symbols[i] for i in idxs]
self.theta_keys = [self.theta_keys[i] for i in idxs] self.theta_keys = [self.theta_keys[i] for i in idxs]
def check_initial_points(self, p0): def _check_initial_points(self, p0):
for nt in range(self.ntemps): for nt in range(self.ntemps):
logging.info('Checking temperature {} chains'.format(nt)) logging.info('Checking temperature {} chains'.format(nt))
initial_priors = np.array([ initial_priors = np.array([
...@@ -193,10 +194,10 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -193,10 +194,10 @@ class MCMCSearch(core.BaseSearchClass):
.format(len(initial_priors), .format(len(initial_priors),
number_of_initial_out_of_bounds)) number_of_initial_out_of_bounds))
p0 = self.generate_new_p0_to_fix_initial_points( p0 = self._generate_new_p0_to_fix_initial_points(
p0, nt, initial_priors) p0, nt, initial_priors)
def generate_new_p0_to_fix_initial_points(self, p0, nt, initial_priors): def _generate_new_p0_to_fix_initial_points(self, p0, nt, initial_priors):
logging.info('Attempting to correct intial values') logging.info('Attempting to correct intial values')
idxs = np.arange(self.nwalkers)[initial_priors == -np.inf] idxs = np.arange(self.nwalkers)[initial_priors == -np.inf]
count = 0 count = 0
...@@ -217,7 +218,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -217,7 +218,7 @@ class MCMCSearch(core.BaseSearchClass):
return p0 return p0
def OLD_run_sampler_with_progress_bar(self, sampler, ns, p0): def _OLD_run_sampler_with_progress_bar(self, sampler, ns, p0):
for result in tqdm(sampler.sample(p0, iterations=ns), total=ns): for result in tqdm(sampler.sample(p0, iterations=ns), total=ns):
pass pass
return sampler return sampler
...@@ -271,7 +272,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -271,7 +272,7 @@ class MCMCSearch(core.BaseSearchClass):
self.convergence_number = 0 self.convergence_number = 0
self.convergence_plot_upper_lim = convergence_plot_upper_lim self.convergence_plot_upper_lim = convergence_plot_upper_lim
def get_convergence_statistic(self, i, sampler): def _get_convergence_statistic(self, i, sampler):
s = sampler.chain[0, :, i-self.convergence_length+1:i+1, :] s = sampler.chain[0, :, i-self.convergence_length+1:i+1, :]
within_std = np.mean(np.var(s, axis=1), axis=0) within_std = np.mean(np.var(s, axis=1), axis=0)
per_walker_mean = np.mean(s, axis=1) per_walker_mean = np.mean(s, axis=1)
...@@ -286,25 +287,25 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -286,25 +287,25 @@ class MCMCSearch(core.BaseSearchClass):
self.convergence_diagnosticx.append(i - self.convergence_length/2) self.convergence_diagnosticx.append(i - self.convergence_length/2)
return c return c
def burnin_convergence_test(self, i, sampler, nburn): def _burnin_convergence_test(self, i, sampler, nburn):
if i < self.convergence_burnin_fraction*nburn: if i < self.convergence_burnin_fraction*nburn:
return False return False
if np.mod(i+1, self.convergence_period) != 0: if np.mod(i+1, self.convergence_period) != 0:
return False return False
c = self.get_convergence_statistic(i, sampler) c = self._get_convergence_statistic(i, sampler)
if np.all(c < self.convergence_threshold): if np.all(c < self.convergence_threshold):
self.convergence_number += 1 self.convergence_number += 1
else: else:
self.convergence_number = 0 self.convergence_number = 0
return self.convergence_number > self.convergence_threshold_number return self.convergence_number > self.convergence_threshold_number
def prod_convergence_test(self, i, sampler, nburn): def _prod_convergence_test(self, i, sampler, nburn):
testA = i > nburn + self.convergence_length testA = i > nburn + self.convergence_length
testB = np.mod(i+1, self.convergence_period) == 0 testB = np.mod(i+1, self.convergence_period) == 0
if testA and testB: if testA and testB:
self.get_convergence_statistic(i, sampler) self._get_convergence_statistic(i, sampler)
def check_production_convergence(self, k): def _check_production_convergence(self, k):
bools = np.any( bools = np.any(
np.array(self.convergence_diagnostic)[k:, :] np.array(self.convergence_diagnostic)[k:, :]
> self.convergence_prod_threshold, axis=1) > self.convergence_prod_threshold, axis=1)
...@@ -313,13 +314,13 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -313,13 +314,13 @@ class MCMCSearch(core.BaseSearchClass):
'{} convergence tests in the production run of {} failed' '{} convergence tests in the production run of {} failed'
.format(np.sum(bools), len(bools))) .format(np.sum(bools), len(bools)))
def run_sampler(self, sampler, p0, nprod=0, nburn=0): def _run_sampler(self, sampler, p0, nprod=0, nburn=0):
if hasattr(self, 'convergence_period'): if hasattr(self, 'convergence_period'):
logging.info('Running {} burn-in steps with convergence testing' logging.info('Running {} burn-in steps with convergence testing'
.format(nburn)) .format(nburn))
iterator = tqdm(sampler.sample(p0, iterations=nburn), total=nburn) iterator = tqdm(sampler.sample(p0, iterations=nburn), total=nburn)
for i, output in enumerate(iterator): for i, output in enumerate(iterator):
if self.burnin_convergence_test(i, sampler, nburn): if self._burnin_convergence_test(i, sampler, nburn):
logging.info( logging.info(
'Converged at {} before max number {} of steps reached' 'Converged at {} before max number {} of steps reached'
.format(i, nburn)) .format(i, nburn))
...@@ -331,9 +332,9 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -331,9 +332,9 @@ class MCMCSearch(core.BaseSearchClass):
k = len(self.convergence_diagnostic) k = len(self.convergence_diagnostic)
for result in tqdm(sampler.sample(output[0], iterations=nprod), for result in tqdm(sampler.sample(output[0], iterations=nprod),
total=nprod): total=nprod):
self.prod_convergence_test(j, sampler, nburn) self._prod_convergence_test(j, sampler, nburn)
j += 1 j += 1
self.check_production_convergence(k) self._check_production_convergence(k)
return sampler return sampler
else: else:
for result in tqdm(sampler.sample(p0, iterations=nburn+nprod), for result in tqdm(sampler.sample(p0, iterations=nburn+nprod),
...@@ -342,50 +343,51 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -342,50 +343,51 @@ class MCMCSearch(core.BaseSearchClass):
return sampler return sampler
def run(self, proposal_scale_factor=2, create_plots=True, **kwargs): def run(self, proposal_scale_factor=2, create_plots=True, **kwargs):
""" Run the MCMC simulatation """
self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use() self.old_data_is_okay_to_use = self._check_old_data_is_okay_to_use()
if self.old_data_is_okay_to_use is True: if self.old_data_is_okay_to_use is True:
logging.warning('Using saved data from {}'.format( logging.warning('Using saved data from {}'.format(
self.pickle_path)) self.pickle_path))
d = self.get_saved_data() d = self.get_saved_data_dictionary()
self.sampler = d['sampler'] self.sampler = d['sampler']
self.samples = d['samples'] self.samples = d['samples']
self.lnprobs = d['lnprobs'] self.lnprobs = d['lnprobs']
self.lnlikes = d['lnlikes'] self.lnlikes = d['lnlikes']
return return
self.initiate_search_object() self._initiate_search_object()
sampler = emcee.PTSampler( sampler = emcee.PTSampler(
self.ntemps, self.nwalkers, self.ndim, self.logl, self.logp, self.ntemps, self.nwalkers, self.ndim, self.logl, self.logp,
logpargs=(self.theta_prior, self.theta_keys, self.search), logpargs=(self.theta_prior, self.theta_keys, self.search),
loglargs=(self.search,), betas=self.betas, a=proposal_scale_factor) loglargs=(self.search,), betas=self.betas, a=proposal_scale_factor)
p0 = self.generate_initial_p0() p0 = self._generate_initial_p0()
p0 = self.apply_corrections_to_p0(p0) p0 = self._apply_corrections_to_p0(p0)
self.check_initial_points(p0) self._check_initial_points(p0)
ninit_steps = len(self.nsteps) - 2 ninit_steps = len(self.nsteps) - 2
for j, n in enumerate(self.nsteps[:-2]): for j, n in enumerate(self.nsteps[:-2]):
logging.info('Running {}/{} initialisation with {} steps'.format( logging.info('Running {}/{} initialisation with {} steps'.format(
j, ninit_steps, n)) j, ninit_steps, n))
sampler = self.run_sampler(sampler, p0, nburn=n) sampler = self._run_sampler(sampler, p0, nburn=n)
logging.info("Mean acceptance fraction: {}" logging.info("Mean acceptance fraction: {}"
.format(np.mean(sampler.acceptance_fraction, axis=1))) .format(np.mean(sampler.acceptance_fraction, axis=1)))
if self.ntemps > 1: if self.ntemps > 1:
logging.info("Tswap acceptance fraction: {}" logging.info("Tswap acceptance fraction: {}"
.format(sampler.tswap_acceptance_fraction)) .format(sampler.tswap_acceptance_fraction))
if create_plots: if create_plots:
fig, axes = self.plot_walkers(sampler, fig, axes = self._plot_walkers(sampler,
symbols=self.theta_symbols, symbols=self.theta_symbols,
**kwargs) **kwargs)
fig.tight_layout() fig.tight_layout()
fig.savefig('{}/{}_init_{}_walkers.png'.format( fig.savefig('{}/{}_init_{}_walkers.png'.format(
self.outdir, self.label, j), dpi=400) self.outdir, self.label, j), dpi=400)
p0 = self.get_new_p0(sampler) p0 = self._get_new_p0(sampler)
p0 = self.apply_corrections_to_p0(p0) p0 = self._apply_corrections_to_p0(p0)
self.check_initial_points(p0) self._check_initial_points(p0)
sampler.reset() sampler.reset()
if len(self.nsteps) > 1: if len(self.nsteps) > 1:
...@@ -395,7 +397,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -395,7 +397,7 @@ class MCMCSearch(core.BaseSearchClass):
nprod = self.nsteps[-1] nprod = self.nsteps[-1]
logging.info('Running final burn and prod with {} steps'.format( logging.info('Running final burn and prod with {} steps'.format(
nburn+nprod)) nburn+nprod))
sampler = self.run_sampler(sampler, p0, nburn=nburn, nprod=nprod) sampler = self._run_sampler(sampler, p0, nburn=nburn, nprod=nprod)
logging.info("Mean acceptance fraction: {}" logging.info("Mean acceptance fraction: {}"
.format(np.mean(sampler.acceptance_fraction, axis=1))) .format(np.mean(sampler.acceptance_fraction, axis=1)))
if self.ntemps > 1: if self.ntemps > 1:
...@@ -403,7 +405,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -403,7 +405,7 @@ class MCMCSearch(core.BaseSearchClass):
.format(sampler.tswap_acceptance_fraction)) .format(sampler.tswap_acceptance_fraction))
if create_plots: if create_plots:
fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols, fig, axes = self._plot_walkers(sampler, symbols=self.theta_symbols,
nprod=nprod, **kwargs) nprod=nprod, **kwargs)
fig.tight_layout() fig.tight_layout()
fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label), fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label),
...@@ -416,9 +418,9 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -416,9 +418,9 @@ class MCMCSearch(core.BaseSearchClass):
self.samples = samples self.samples = samples
self.lnprobs = lnprobs self.lnprobs = lnprobs
self.lnlikes = lnlikes self.lnlikes = lnlikes
self.save_data(sampler, samples, lnprobs, lnlikes) self._save_data(sampler, samples, lnprobs, lnlikes)
def get_rescale_multiplier_for_key(self, key): def _get_rescale_multiplier_for_key(self, key):
""" Get the rescale multiplier from the rescale_dictionary """ Get the rescale multiplier from the rescale_dictionary
Can either be a float, a string (in which case it is interpretted as Can either be a float, a string (in which case it is interpretted as
...@@ -443,7 +445,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -443,7 +445,7 @@ class MCMCSearch(core.BaseSearchClass):
multiplier = 1 multiplier = 1
return multiplier return multiplier
def get_rescale_subtractor_for_key(self, key): def _get_rescale_subtractor_for_key(self, key):
""" Get the rescale subtractor from the rescale_dictionary """ Get the rescale subtractor from the rescale_dictionary
Can either be a float, a string (in which case it is interpretted as Can either be a float, a string (in which case it is interpretted as
...@@ -468,21 +470,21 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -468,21 +470,21 @@ class MCMCSearch(core.BaseSearchClass):
subtractor = 0 subtractor = 0
return subtractor return subtractor
def scale_samples(self, samples, theta_keys): def _scale_samples(self, samples, theta_keys):
""" Scale the samples using the rescale_dictionary """ """ Scale the samples using the rescale_dictionary """
for key in theta_keys: for key in theta_keys:
if key in self.rescale_dictionary: if key in self.rescale_dictionary:
idx = theta_keys.index(key) idx = theta_keys.index(key)
s = samples[:, idx] s = samples[:, idx]
subtractor = self.get_rescale_subtractor_for_key(key) subtractor = self._get_rescale_subtractor_for_key(key)
s = s - subtractor s = s - subtractor
multiplier = self.get_rescale_multiplier_for_key(key) multiplier = self._get_rescale_multiplier_for_key(key)
s *= multiplier s *= multiplier
samples[:, idx] = s samples[:, idx] = s
return samples return samples
def get_labels(self): def _get_labels(self):