Skip to content
Snippets Groups Projects
Commit 5bd1af86 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Adds ability to input initial as list and improves logging

parent e3b2d3eb
Branches
Tags
No related merge requests found
...@@ -800,11 +800,21 @@ class MCMCSearch(BaseSearchClass): ...@@ -800,11 +800,21 @@ class MCMCSearch(BaseSearchClass):
""" Generate a set of init vals for the walkers """ """ Generate a set of init vals for the walkers """
if type(self.theta_initial) == dict: if type(self.theta_initial) == dict:
logging.info('Generate initial values from initial dictionary')
if self.nglitch > 1:
raise ValueError('Initial dict not implemented for nglitch>1')
p0 = [[[self.generate_rv(**self.theta_initial[key]) p0 = [[[self.generate_rv(**self.theta_initial[key])
for key in self.theta_keys] for key in self.theta_keys]
for i in range(self.nwalkers)] for i in range(self.nwalkers)]
for j in range(self.ntemps)] for j in range(self.ntemps)]
elif type(self.theta_initial) == list:
logging.info('Generate initial values from list of theta_initial')
p0 = [[[self.generate_rv(**val)
for val in self.theta_initial]
for i in range(self.nwalkers)]
for j in range(self.ntemps)]
elif self.theta_initial is None: elif self.theta_initial is None:
logging.info('Generate initial values from prior dictionary')
p0 = [[[self.generate_rv(**self.theta_prior[key]) p0 = [[[self.generate_rv(**self.theta_prior[key])
for key in self.theta_keys] for key in self.theta_keys]
for i in range(self.nwalkers)] for i in range(self.nwalkers)]
...@@ -958,12 +968,15 @@ class MCMCSearch(BaseSearchClass): ...@@ -958,12 +968,15 @@ class MCMCSearch(BaseSearchClass):
d = OrderedDict() d = OrderedDict()
lnl_finite = copy.copy(self.lnlikes) lnl_finite = copy.copy(self.lnlikes)
lnl_finite[idxs] = np.nan lnl_finite[idxs] = 0
close_idxs = abs((maxtwoF - lnl_finite) / maxtwoF) < threshold close_idxs = abs((maxtwoF - lnl_finite) / maxtwoF) < threshold
for i, k in enumerate(self.theta_keys): for i, k in enumerate(self.theta_keys):
ng = 1 ng = 1
while k in d: while k in d:
k = k.rstrip('_{}'.format(ng-1)) + '_{}'.format(ng) if k == 1:
k = k + '_1'
else:
k.replace('_{}'.format(ng-1), '_{}'.format(ng))
ng += 1 ng += 1
d[k] = self.samples[jmax][i] d[k] = self.samples[jmax][i]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment