diff --git a/gridcorner.py b/gridcorner.py index 0d9dc264ce6d15e93ed43367182ad7ea0bb77a28..e8c3a08425b3a279a7065573e34104ed643f3b12 100644 --- a/gridcorner.py +++ b/gridcorner.py @@ -48,7 +48,7 @@ def idx_array_slice(D, axis, slice_idx): def gridcorner(D, xyz, labels=None, projection='max_slice', max_n_ticks=4, - factor=3, whspace=0.05, showDvals=True, **kwargs): + factor=3, whspace=0.05, showDvals=True, lines=None, **kwargs): """ Generate a grid corner plot Parameters @@ -75,6 +75,8 @@ def gridcorner(D, xyz, labels=None, projection='max_slice', max_n_ticks=4, showDvals: bool If true (default) show the D values on the right-hand-side of the 1D plots and add a label. + lines: array_like + N-dimensional list of values to delineate. Returns ------- @@ -106,7 +108,7 @@ def gridcorner(D, xyz, labels=None, projection='max_slice', max_n_ticks=4, for i in range(ndim): projection_1D( axes[i, i], xyz[i], D, i, projection=projection, - showDvals=showDvals, **kwargs) + showDvals=showDvals, lines=lines, **kwargs) for j in range(ndim): ax = axes[i, j] @@ -129,7 +131,7 @@ def gridcorner(D, xyz, labels=None, projection='max_slice', max_n_ticks=4, ax.xaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="upper")) ax.yaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="upper")) - ax, pax = projection_2D(ax, xyz[i], xyz[j], D, i, j, + ax, pax = projection_2D(ax, xyz[i], xyz[j], D, i, j, lines=lines, projection=projection, **kwargs) if labels: @@ -142,27 +144,32 @@ def gridcorner(D, xyz, labels=None, projection='max_slice', max_n_ticks=4, return fig, axes -def projection_2D(ax, x, y, D, xidx, yidx, projection, **kwargs): +def projection_2D(ax, x, y, D, xidx, yidx, projection, lines=None, **kwargs): flat_idxs = range(D.ndim) flat_idxs.remove(xidx) flat_idxs.remove(yidx) D2D = projection(D, axis=tuple(flat_idxs), **kwargs) X, Y = np.meshgrid(x, y, indexing='ij') pax = ax.pcolormesh(Y, X, D2D.T, vmin=D.min(), vmax=D.max()) + if lines: + ax.axhline(lines[xidx], lw=0.5, color='w') + ax.axvline(lines[yidx], lw=0.5, color='w') return ax, pax -def projection_1D(ax, x, D, xidx, projection, showDvals=True, +def projection_1D(ax, x, D, xidx, projection, showDvals=True, lines=None, **kwargs): flat_idxs = range(D.ndim) flat_idxs.remove(xidx) D1D = projection(D, axis=tuple(flat_idxs), **kwargs) - ax.plot(x, D1D) + ax.plot(x, D1D, color='k') if showDvals: ax.yaxis.tick_right() ax.yaxis.set_label_position("right") else: ax.yaxis.set_ticklabels([]) + if lines: + ax.axvline(lines[xidx], lw=0.5, color='C0') return ax