diff --git a/pyfstat.py b/pyfstat.py index dea022d2137808eb3a67bca851a84b7e903b7faa..c3a9adf095a4989a1e1cc50fe104d33334ad2d8b 100755 --- a/pyfstat.py +++ b/pyfstat.py @@ -1607,7 +1607,8 @@ class GridSearch(BaseSearchClass): def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None, add_mismatch=None, xN=None, yN=None, flat_keys=[], - rel_flat_idxs=[], flatten_method=np.max): + rel_flat_idxs=[], flatten_method=np.max, + predicted_twoF=None, cm=None): """ Plots a 2D grid of 2F values Parameters @@ -1636,7 +1637,15 @@ class GridSearch(BaseSearchClass): if len(rel_flat_idxs) > 0: Z = flatten_method(Z, axis=tuple(rel_flat_idxs)) - pax = ax.pcolormesh(X, Y, Z, cmap=plt.cm.viridis, vmin=vmin, vmax=vmax) + if predicted_twoF: + Z = (predicted_twoF - Z) / (predicted_twoF + 4) + if cm is None: + cm = plt.cm.viridis_r + else: + if cm is None: + cm = plt.cm.viridis + + pax = ax.pcolormesh(X, Y, Z, cmap=cm, vmin=vmin, vmax=vmax) plt.colorbar(pax, ax=ax) if add_mismatch: