Commit c7830615 authored by Gregory Ashton's avatar Gregory Ashton

Add 'lines' - i.e. like the truths in Corner.py

parent 2fcc71c4
......@@ -48,7 +48,7 @@ def idx_array_slice(D, axis, slice_idx):
def gridcorner(D, xyz, labels=None, projection='max_slice', max_n_ticks=4,
factor=3, whspace=0.05, showDvals=True, **kwargs):
factor=3, whspace=0.05, showDvals=True, lines=None, **kwargs):
""" Generate a grid corner plot
Parameters
......@@ -75,6 +75,8 @@ def gridcorner(D, xyz, labels=None, projection='max_slice', max_n_ticks=4,
showDvals: bool
If true (default) show the D values on the right-hand-side of the
1D plots and add a label.
lines: array_like
N-dimensional list of values to delineate.
Returns
-------
......@@ -106,7 +108,7 @@ def gridcorner(D, xyz, labels=None, projection='max_slice', max_n_ticks=4,
for i in range(ndim):
projection_1D(
axes[i, i], xyz[i], D, i, projection=projection,
showDvals=showDvals, **kwargs)
showDvals=showDvals, lines=lines, **kwargs)
for j in range(ndim):
ax = axes[i, j]
......@@ -129,7 +131,7 @@ def gridcorner(D, xyz, labels=None, projection='max_slice', max_n_ticks=4,
ax.xaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="upper"))
ax.yaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="upper"))
ax, pax = projection_2D(ax, xyz[i], xyz[j], D, i, j,
ax, pax = projection_2D(ax, xyz[i], xyz[j], D, i, j, lines=lines,
projection=projection, **kwargs)
if labels:
......@@ -142,27 +144,32 @@ def gridcorner(D, xyz, labels=None, projection='max_slice', max_n_ticks=4,
return fig, axes
def projection_2D(ax, x, y, D, xidx, yidx, projection, **kwargs):
def projection_2D(ax, x, y, D, xidx, yidx, projection, lines=None, **kwargs):
flat_idxs = range(D.ndim)
flat_idxs.remove(xidx)
flat_idxs.remove(yidx)
D2D = projection(D, axis=tuple(flat_idxs), **kwargs)
X, Y = np.meshgrid(x, y, indexing='ij')
pax = ax.pcolormesh(Y, X, D2D.T, vmin=D.min(), vmax=D.max())
if lines:
ax.axhline(lines[xidx], lw=0.5, color='w')
ax.axvline(lines[yidx], lw=0.5, color='w')
return ax, pax
def projection_1D(ax, x, D, xidx, projection, showDvals=True,
def projection_1D(ax, x, D, xidx, projection, showDvals=True, lines=None,
**kwargs):
flat_idxs = range(D.ndim)
flat_idxs.remove(xidx)
D1D = projection(D, axis=tuple(flat_idxs), **kwargs)
ax.plot(x, D1D)
ax.plot(x, D1D, color='k')
if showDvals:
ax.yaxis.tick_right()
ax.yaxis.set_label_position("right")
else:
ax.yaxis.set_ticklabels([])
if lines:
ax.axvline(lines[xidx], lw=0.5, color='C0')
return ax
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment