diff --git a/pyfstat.py b/pyfstat.py index 64642db2f49ccb817ab1463e0676df200da52510..322a0763d855dc31b409ee7b88276cf8c9ae3bdd 100755 --- a/pyfstat.py +++ b/pyfstat.py @@ -1599,8 +1599,8 @@ class GridSearch(BaseSearchClass): fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label)) def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None, - add_mismatch=None, xN=None, yN=None, flat_keys=[], - flatten_method=np.max): + add_mismatch=None, xN=None, yN=None, flat_keys=[], + rel_flat_idxs=[], flatten_method=np.max): """ Plots a 2D grid of 2F values Parameters @@ -1626,8 +1626,8 @@ class GridSearch(BaseSearchClass): shape = [len(x), len(y)] + [len(v) for v in flat_vals] Z = z.reshape(shape) - while Z.ndim > 2: - Z = flatten_method(Z, axis=-1) + if len(rel_flat_idxs) > 0: + Z = flatten_method(Z, axis=tuple(rel_flat_idxs)) pax = ax.pcolormesh(X, Y, Z, cmap=plt.cm.viridis, vmin=vmin, vmax=vmax) plt.colorbar(pax, ax=ax)