diff --git a/examples/transient_examples/transient_search_using_MCMC.py b/examples/transient_examples/transient_search_using_MCMC.py index b70dc3d85c1d17d5ab5906887dfabc4184922f5c..ec5345efa35b1d7342a7eefa00115867db625373 100644 --- a/examples/transient_examples/transient_search_using_MCMC.py +++ b/examples/transient_examples/transient_search_using_MCMC.py @@ -25,9 +25,7 @@ theta_prior = {'F0': {'type': 'unif', 'F2': F2, 'Alpha': Alpha, 'Delta': Delta, - 'transient_tstart': {'type': 'unif', - 'lower': minStartTime, - 'upper': maxStartTime}, + 'transient_tstart': minStartTime, 'transient_duration': {'type': 'halfnorm', 'loc': 0.001*Tspan, 'scale': 0.5*Tspan} diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py index 912445e33f2e2846710b8edaf70df7490f262144..fa13b688b96061e5e3d00f671e777763f127a4b2 100644 --- a/pyfstat/mcmc_based_searches.py +++ b/pyfstat/mcmc_based_searches.py @@ -212,38 +212,39 @@ class MCMCSearch(core.BaseSearchClass): self.theta_symbols = [self.theta_symbols[i] for i in idxs] self.theta_keys = [self.theta_keys[i] for i in idxs] + def _evaluate_logpost(self, p0vec): + init_logp = np.array([ + self.logp(p, self.theta_prior, self.theta_keys, self.search) + for p in p0vec]) + init_logl = np.array([ + self.logl(p, self.search) + for p in p0vec]) + return init_logl + init_logp + def _check_initial_points(self, p0): for nt in range(self.ntemps): logging.info('Checking temperature {} chains'.format(nt)) - initial_priors = np.array([ - self.logp(p, self.theta_prior, self.theta_keys, self.search) - for p in p0[nt]]) - number_of_initial_out_of_bounds = sum(initial_priors == -np.inf) - - if number_of_initial_out_of_bounds > 0: + num = sum(self._evaluate_logpost(p0[nt]) == -np.inf) + if num > 0: logging.warning( 'Of {} initial values, {} are -np.inf due to the prior' - .format(len(initial_priors), - number_of_initial_out_of_bounds)) - + .format(len(p0[0]), num)) p0 = self._generate_new_p0_to_fix_initial_points( - p0, nt, initial_priors) + p0, nt) - def _generate_new_p0_to_fix_initial_points(self, p0, nt, initial_priors): + def _generate_new_p0_to_fix_initial_points(self, p0, nt): logging.info('Attempting to correct intial values') - idxs = np.arange(self.nwalkers)[initial_priors == -np.inf] + init_logpost = self._evaluate_logpost(p0[nt]) + idxs = np.arange(self.nwalkers)[init_logpost == -np.inf] count = 0 - while sum(initial_priors == -np.inf) > 0 and count < 100: + while sum(init_logpost == -np.inf) > 0 and count < 100: for j in idxs: p0[nt][j] = (p0[nt][np.random.randint(0, self.nwalkers)]*( 1+np.random.normal(0, 1e-10, self.ndim))) - initial_priors = np.array([ - self.logp(p, self.theta_prior, self.theta_keys, - self.search) - for p in p0[nt]]) + init_logpost = self._evaluate_logpost(p0[nt]) count += 1 - if sum(initial_priors == -np.inf) > 0: + if sum(init_logpost == -np.inf) > 0: logging.info('Failed to fix initial priors') else: logging.info('Suceeded to fix initial priors')