diff --git a/pyfstat.py b/pyfstat.py
index decacc5176e5ce87317f755b89efe7303f62e2e8..7e92fd26673b8fb4a79f7944ae17e5aaf46a1d19 100755
--- a/pyfstat.py
+++ b/pyfstat.py
@@ -984,36 +984,32 @@ class MCMCSearch(BaseSearchClass):
         the maximum posterior with scale `scatter_val`.
 
         """
-        if sampler.chain[:, :, -1, :].shape[0] == 1:
-            ntemps_temp = 1
-        else:
-            ntemps_temp = self.ntemps
-        pF = sampler.chain[:, :, -1, :].reshape(
-            ntemps_temp, self.nwalkers, self.ndim)[0, :, :]
-        lnl = sampler.lnlikelihood[:, :, -1].reshape(
-            self.ntemps, self.nwalkers)[0, :]
-        lnp = sampler.lnprobability[:, :, -1].reshape(
-            self.ntemps, self.nwalkers)[0, :]
+        temp_idx = 0
+        pF = sampler.chain[temp_idx, :, :, :]
+        lnl = sampler.lnlikelihood[temp_idx, :, :]
+        lnp = sampler.lnprobability[temp_idx, :, :]
 
         # General warnings about the state of lnp
-        if any(np.isnan(lnp)):
+        if np.any(np.isnan(lnp)):
             logging.warning(
                 "Of {} lnprobs {} are nan".format(
-                    len(lnp), np.sum(np.isnan(lnp))))
-        if any(np.isposinf(lnp)):
+                    np.shape(lnp), np.sum(np.isnan(lnp))))
+        if np.any(np.isposinf(lnp)):
             logging.warning(
                 "Of {} lnprobs {} are +np.inf".format(
-                    len(lnp), np.sum(np.isposinf(lnp))))
-        if any(np.isneginf(lnp)):
+                    np.shape(lnp), np.sum(np.isposinf(lnp))))
+        if np.any(np.isneginf(lnp)):
             logging.warning(
                 "Of {} lnprobs {} are -np.inf".format(
-                    len(lnp), np.sum(np.isneginf(lnp))))
+                    np.shape(lnp), np.sum(np.isneginf(lnp))))
 
         lnp_finite = copy.copy(lnp)
         lnp_finite[np.isinf(lnp)] = np.nan
-        p = pF[np.nanargmax(lnp_finite)]
-        logging.info('Generating new p0 from max lnp which had twoF={}'
-                     .format(lnl[np.nanargmax(lnp_finite)]))
+        idx = np.unravel_index(np.nanargmax(lnp_finite), lnp_finite.shape)
+        logging.info(('Gen. new p0 from max lnp (walker {}, pos {})'
+                      ' which had twoF={} ')
+                     .format(idx[0], idx[1], lnl[idx]))
+        p = pF[idx]
         p0 = self.generate_scattered_p0(p)
 
         return p0