diff --git a/pyfstat.py b/pyfstat.py
index c26a5dc2ea4543529b1e82691a00459cf02b010a..b76daf309f261a60d171ff87e8685cebcb5a10a0 100755
--- a/pyfstat.py
+++ b/pyfstat.py
@@ -543,14 +543,42 @@ class MCMCSearch(BaseSearchClass):
         self.theta_keys = [self.theta_keys[i] for i in idxs]
 
     def check_initial_points(self, p0):
-        initial_priors = np.array([
-            self.logp(p, self.theta_prior, self.theta_keys, self.search)
-            for p in p0[0]])
-        number_of_initial_out_of_bounds = sum(initial_priors == -np.inf)
-        if number_of_initial_out_of_bounds > 0:
-            logging.warning(
-                'Of {} initial values, {} are -np.inf due to the prior'.format(
-                    len(initial_priors), number_of_initial_out_of_bounds))
+        for nt in range(self.ntemps):
+            logging.info('Checking temperature {} chains'.format(nt))
+            initial_priors = np.array([
+                self.logp(p, self.theta_prior, self.theta_keys, self.search)
+                for p in p0[nt]])
+            number_of_initial_out_of_bounds = sum(initial_priors == -np.inf)
+
+            if number_of_initial_out_of_bounds > 0:
+                logging.warning(
+                    'Of {} initial values, {} are -np.inf due to the prior'
+                    .format(len(initial_priors),
+                            number_of_initial_out_of_bounds))
+
+                p0 = self.generate_new_p0_to_fix_initial_points(
+                    p0, nt, initial_priors)
+
+    def generate_new_p0_to_fix_initial_points(self, p0, nt, initial_priors):
+        logging.info('Attempting to correct intial values')
+        idxs = np.arange(self.nwalkers)[initial_priors == -np.inf]
+        count = 0
+        while sum(initial_priors == -np.inf) > 0 and count < 100:
+            for j in idxs:
+                p0[nt][j] = (p0[nt][np.random.randint(0, self.nwalkers)]*(
+                             1+np.random.normal(0, 1e-10, self.ndim)))
+            initial_priors = np.array([
+                self.logp(p, self.theta_prior, self.theta_keys,
+                          self.search)
+                for p in p0[nt]])
+            count += 1
+
+        if sum(initial_priors == -np.inf) > 0:
+            logging.info('Failed to fix initial priors')
+        else:
+            logging.info('Suceeded to fix initial priors')
+
+        return p0
 
     def run(self):