diff --git a/pyfstat.py b/pyfstat.py index c26a5dc2ea4543529b1e82691a00459cf02b010a..b76daf309f261a60d171ff87e8685cebcb5a10a0 100755 --- a/pyfstat.py +++ b/pyfstat.py @@ -543,14 +543,42 @@ class MCMCSearch(BaseSearchClass): self.theta_keys = [self.theta_keys[i] for i in idxs] def check_initial_points(self, p0): - initial_priors = np.array([ - self.logp(p, self.theta_prior, self.theta_keys, self.search) - for p in p0[0]]) - number_of_initial_out_of_bounds = sum(initial_priors == -np.inf) - if number_of_initial_out_of_bounds > 0: - logging.warning( - 'Of {} initial values, {} are -np.inf due to the prior'.format( - len(initial_priors), number_of_initial_out_of_bounds)) + for nt in range(self.ntemps): + logging.info('Checking temperature {} chains'.format(nt)) + initial_priors = np.array([ + self.logp(p, self.theta_prior, self.theta_keys, self.search) + for p in p0[nt]]) + number_of_initial_out_of_bounds = sum(initial_priors == -np.inf) + + if number_of_initial_out_of_bounds > 0: + logging.warning( + 'Of {} initial values, {} are -np.inf due to the prior' + .format(len(initial_priors), + number_of_initial_out_of_bounds)) + + p0 = self.generate_new_p0_to_fix_initial_points( + p0, nt, initial_priors) + + def generate_new_p0_to_fix_initial_points(self, p0, nt, initial_priors): + logging.info('Attempting to correct intial values') + idxs = np.arange(self.nwalkers)[initial_priors == -np.inf] + count = 0 + while sum(initial_priors == -np.inf) > 0 and count < 100: + for j in idxs: + p0[nt][j] = (p0[nt][np.random.randint(0, self.nwalkers)]*( + 1+np.random.normal(0, 1e-10, self.ndim))) + initial_priors = np.array([ + self.logp(p, self.theta_prior, self.theta_keys, + self.search) + for p in p0[nt]]) + count += 1 + + if sum(initial_priors == -np.inf) > 0: + logging.info('Failed to fix initial priors') + else: + logging.info('Suceeded to fix initial priors') + + return p0 def run(self):