From a94d215ea7ef9fbca52629985d335766a43fabd0 Mon Sep 17 00:00:00 2001
From: "gregory.ashton" <gregory.ashton@ligo.org>
Date: Tue, 11 Oct 2016 11:42:30 +0200
Subject: [PATCH] Improves pruning method

If an initialisation step is callled, the sampler runs the simulations
for nsteps and selects the point with the highest lnprobability from
which to restart the simulations. Previously, only the last step was
used. In this update all steps are used and a message is printed to
allow the user to determine if the selected point is appropriate.
---
 pyfstat.py | 34 +++++++++++++++-------------------
 1 file changed, 15 insertions(+), 19 deletions(-)

diff --git a/pyfstat.py b/pyfstat.py
index decacc5..7e92fd2 100755
--- a/pyfstat.py
+++ b/pyfstat.py
@@ -984,36 +984,32 @@ class MCMCSearch(BaseSearchClass):
         the maximum posterior with scale `scatter_val`.
 
         """
-        if sampler.chain[:, :, -1, :].shape[0] == 1:
-            ntemps_temp = 1
-        else:
-            ntemps_temp = self.ntemps
-        pF = sampler.chain[:, :, -1, :].reshape(
-            ntemps_temp, self.nwalkers, self.ndim)[0, :, :]
-        lnl = sampler.lnlikelihood[:, :, -1].reshape(
-            self.ntemps, self.nwalkers)[0, :]
-        lnp = sampler.lnprobability[:, :, -1].reshape(
-            self.ntemps, self.nwalkers)[0, :]
+        temp_idx = 0
+        pF = sampler.chain[temp_idx, :, :, :]
+        lnl = sampler.lnlikelihood[temp_idx, :, :]
+        lnp = sampler.lnprobability[temp_idx, :, :]
 
         # General warnings about the state of lnp
-        if any(np.isnan(lnp)):
+        if np.any(np.isnan(lnp)):
             logging.warning(
                 "Of {} lnprobs {} are nan".format(
-                    len(lnp), np.sum(np.isnan(lnp))))
-        if any(np.isposinf(lnp)):
+                    np.shape(lnp), np.sum(np.isnan(lnp))))
+        if np.any(np.isposinf(lnp)):
             logging.warning(
                 "Of {} lnprobs {} are +np.inf".format(
-                    len(lnp), np.sum(np.isposinf(lnp))))
-        if any(np.isneginf(lnp)):
+                    np.shape(lnp), np.sum(np.isposinf(lnp))))
+        if np.any(np.isneginf(lnp)):
             logging.warning(
                 "Of {} lnprobs {} are -np.inf".format(
-                    len(lnp), np.sum(np.isneginf(lnp))))
+                    np.shape(lnp), np.sum(np.isneginf(lnp))))
 
         lnp_finite = copy.copy(lnp)
         lnp_finite[np.isinf(lnp)] = np.nan
-        p = pF[np.nanargmax(lnp_finite)]
-        logging.info('Generating new p0 from max lnp which had twoF={}'
-                     .format(lnl[np.nanargmax(lnp_finite)]))
+        idx = np.unravel_index(np.nanargmax(lnp_finite), lnp_finite.shape)
+        logging.info(('Gen. new p0 from max lnp (walker {}, pos {})'
+                      ' which had twoF={} ')
+                     .format(idx[0], idx[1], lnl[idx]))
+        p = pF[idx]
         p0 = self.generate_scattered_p0(p)
 
         return p0
-- 
GitLab