diff --git a/examples/transient_examples/transient_search_using_MCMC.py b/examples/transient_examples/transient_search_using_MCMC.py
index b70dc3d85c1d17d5ab5906887dfabc4184922f5c..ec5345efa35b1d7342a7eefa00115867db625373 100644
--- a/examples/transient_examples/transient_search_using_MCMC.py
+++ b/examples/transient_examples/transient_search_using_MCMC.py
@@ -25,9 +25,7 @@ theta_prior = {'F0': {'type': 'unif',
                'F2': F2,
                'Alpha': Alpha,
                'Delta': Delta,
-               'transient_tstart': {'type': 'unif',
-                                    'lower': minStartTime,
-                                    'upper': maxStartTime},
+               'transient_tstart': minStartTime,
                'transient_duration': {'type': 'halfnorm',
                                       'loc': 0.001*Tspan,
                                       'scale': 0.5*Tspan}
diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py
index 912445e33f2e2846710b8edaf70df7490f262144..fa13b688b96061e5e3d00f671e777763f127a4b2 100644
--- a/pyfstat/mcmc_based_searches.py
+++ b/pyfstat/mcmc_based_searches.py
@@ -212,38 +212,39 @@ class MCMCSearch(core.BaseSearchClass):
         self.theta_symbols = [self.theta_symbols[i] for i in idxs]
         self.theta_keys = [self.theta_keys[i] for i in idxs]
 
+    def _evaluate_logpost(self, p0vec):
+        init_logp = np.array([
+            self.logp(p, self.theta_prior, self.theta_keys, self.search)
+            for p in p0vec])
+        init_logl = np.array([
+            self.logl(p, self.search)
+            for p in p0vec])
+        return init_logl + init_logp
+
     def _check_initial_points(self, p0):
         for nt in range(self.ntemps):
             logging.info('Checking temperature {} chains'.format(nt))
-            initial_priors = np.array([
-                self.logp(p, self.theta_prior, self.theta_keys, self.search)
-                for p in p0[nt]])
-            number_of_initial_out_of_bounds = sum(initial_priors == -np.inf)
-
-            if number_of_initial_out_of_bounds > 0:
+            num = sum(self._evaluate_logpost(p0[nt]) == -np.inf)
+            if num > 0:
                 logging.warning(
                     'Of {} initial values, {} are -np.inf due to the prior'
-                    .format(len(initial_priors),
-                            number_of_initial_out_of_bounds))
-
+                    .format(len(p0[0]), num))
                 p0 = self._generate_new_p0_to_fix_initial_points(
-                    p0, nt, initial_priors)
+                    p0, nt)
 
-    def _generate_new_p0_to_fix_initial_points(self, p0, nt, initial_priors):
+    def _generate_new_p0_to_fix_initial_points(self, p0, nt):
         logging.info('Attempting to correct intial values')
-        idxs = np.arange(self.nwalkers)[initial_priors == -np.inf]
+        init_logpost = self._evaluate_logpost(p0[nt])
+        idxs = np.arange(self.nwalkers)[init_logpost == -np.inf]
         count = 0
-        while sum(initial_priors == -np.inf) > 0 and count < 100:
+        while sum(init_logpost == -np.inf) > 0 and count < 100:
             for j in idxs:
                 p0[nt][j] = (p0[nt][np.random.randint(0, self.nwalkers)]*(
                              1+np.random.normal(0, 1e-10, self.ndim)))
-            initial_priors = np.array([
-                self.logp(p, self.theta_prior, self.theta_keys,
-                          self.search)
-                for p in p0[nt]])
+            init_logpost = self._evaluate_logpost(p0[nt])
             count += 1
 
-        if sum(initial_priors == -np.inf) > 0:
+        if sum(init_logpost == -np.inf) > 0:
             logging.info('Failed to fix initial priors')
         else:
             logging.info('Suceeded to fix initial priors')