From d7f2a89f1e8cc59a24961cc90e7a94875ae6f887 Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Fri, 7 Oct 2016 15:48:38 +0200
Subject: [PATCH] Adds ability to 2D plot data in higher dimension

---
 pyfstat.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/pyfstat.py b/pyfstat.py
index 256c9bd..bb5fb3c 100755
--- a/pyfstat.py
+++ b/pyfstat.py
@@ -1487,7 +1487,7 @@ class GridSearch(BaseSearchClass):
         fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label))
 
     def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None,
-                add_mismatch=None, xN=None, yN=None):
+                add_mismatch=None, xN=None, yN=None, flat_keys=[]):
         """ Plots a 2D grid of 2F values
 
         Parameters
@@ -1500,12 +1500,19 @@ class GridSearch(BaseSearchClass):
             fig, ax = plt.subplots()
         xidx = self.keys.index(xkey)
         yidx = self.keys.index(ykey)
+        flat_idxs = [self.keys.index(k) for k in flat_keys]
+
         x = np.unique(self.data[:, xidx])
         y = np.unique(self.data[:, yidx])
+        flat_vals = [np.unique(self.data[:, j]) for j in flat_idxs]
         z = self.data[:, -1]
 
         Y, X = np.meshgrid(y, x)
-        Z = z.reshape(X.shape)
+        shape = [len(x), len(y)] + [len(v) for v in flat_vals]
+        Z = z.reshape(shape)
+
+        while Z.ndim > 2:
+            Z = np.mean(Z, axis=-1)
 
         pax = ax.pcolormesh(X, Y, Z, cmap=plt.cm.viridis, vmin=vmin, vmax=vmax)
         plt.colorbar(pax, ax=ax)
-- 
GitLab