From 25e5f1ce359c39d6c9a99fd047177de2973edafc Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Tue, 27 Sep 2016 12:12:31 +0200
Subject: [PATCH] Adds process to 'fix' initial p0

If the initial p0 values are -np.inf due to the prior, it will iterate
through the other p0 values to replace the duff values. The replacement
takes another randomly selected p0 value and shifts it round by a
fractional gaussian with std=1e-10
---
 pyfstat.py | 44 ++++++++++++++++++++++++++++++++++++--------
 1 file changed, 36 insertions(+), 8 deletions(-)

diff --git a/pyfstat.py b/pyfstat.py
index c26a5dc..b76daf3 100755
--- a/pyfstat.py
+++ b/pyfstat.py
@@ -543,14 +543,42 @@ class MCMCSearch(BaseSearchClass):
         self.theta_keys = [self.theta_keys[i] for i in idxs]
 
     def check_initial_points(self, p0):
-        initial_priors = np.array([
-            self.logp(p, self.theta_prior, self.theta_keys, self.search)
-            for p in p0[0]])
-        number_of_initial_out_of_bounds = sum(initial_priors == -np.inf)
-        if number_of_initial_out_of_bounds > 0:
-            logging.warning(
-                'Of {} initial values, {} are -np.inf due to the prior'.format(
-                    len(initial_priors), number_of_initial_out_of_bounds))
+        for nt in range(self.ntemps):
+            logging.info('Checking temperature {} chains'.format(nt))
+            initial_priors = np.array([
+                self.logp(p, self.theta_prior, self.theta_keys, self.search)
+                for p in p0[nt]])
+            number_of_initial_out_of_bounds = sum(initial_priors == -np.inf)
+
+            if number_of_initial_out_of_bounds > 0:
+                logging.warning(
+                    'Of {} initial values, {} are -np.inf due to the prior'
+                    .format(len(initial_priors),
+                            number_of_initial_out_of_bounds))
+
+                p0 = self.generate_new_p0_to_fix_initial_points(
+                    p0, nt, initial_priors)
+
+    def generate_new_p0_to_fix_initial_points(self, p0, nt, initial_priors):
+        logging.info('Attempting to correct intial values')
+        idxs = np.arange(self.nwalkers)[initial_priors == -np.inf]
+        count = 0
+        while sum(initial_priors == -np.inf) > 0 and count < 100:
+            for j in idxs:
+                p0[nt][j] = (p0[nt][np.random.randint(0, self.nwalkers)]*(
+                             1+np.random.normal(0, 1e-10, self.ndim)))
+            initial_priors = np.array([
+                self.logp(p, self.theta_prior, self.theta_keys,
+                          self.search)
+                for p in p0[nt]])
+            count += 1
+
+        if sum(initial_priors == -np.inf) > 0:
+            logging.info('Failed to fix initial priors')
+        else:
+            logging.info('Suceeded to fix initial priors')
+
+        return p0
 
     def run(self):
 
-- 
GitLab