diff --git a/pyfstat.py b/pyfstat.py index 256c9bd62bf59234404eb7f7140d992b884dfd25..bb5fb3c1c7b1bafe83c78e48ce8c9d335730f184 100755 --- a/pyfstat.py +++ b/pyfstat.py @@ -1487,7 +1487,7 @@ class GridSearch(BaseSearchClass): fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label)) def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None, - add_mismatch=None, xN=None, yN=None): + add_mismatch=None, xN=None, yN=None, flat_keys=[]): """ Plots a 2D grid of 2F values Parameters @@ -1500,12 +1500,19 @@ class GridSearch(BaseSearchClass): fig, ax = plt.subplots() xidx = self.keys.index(xkey) yidx = self.keys.index(ykey) + flat_idxs = [self.keys.index(k) for k in flat_keys] + x = np.unique(self.data[:, xidx]) y = np.unique(self.data[:, yidx]) + flat_vals = [np.unique(self.data[:, j]) for j in flat_idxs] z = self.data[:, -1] Y, X = np.meshgrid(y, x) - Z = z.reshape(X.shape) + shape = [len(x), len(y)] + [len(v) for v in flat_vals] + Z = z.reshape(shape) + + while Z.ndim > 2: + Z = np.mean(Z, axis=-1) pax = ax.pcolormesh(X, Y, Z, cmap=plt.cm.viridis, vmin=vmin, vmax=vmax) plt.colorbar(pax, ax=ax)