# Select Git revision
# condor_submit_Interpolate0001_t_10M.sub
# Forked from
# Xisco Jimenez Forteza / RDStackingProject
# Source project has a limited visibility.
# gridcorner.py 5.72 KiB
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

try:
    # logsumexp lives in scipy.special; the scipy.misc alias has been removed
    # from current SciPy releases.
    from scipy.special import logsumexp
except ImportError:  # fall back to the location used by very old SciPy
    from scipy.misc import logsumexp
def log_mean(loga, axis=None):
    """ Calculate the log(<a>) mean

    Given `N` logged values `loga`, calculate the log_mean
    `log(<a>) = log(sum(np.exp(loga))) - log(N)`. Useful for marginalizing
    over logged likelihoods for example.

    Parameters
    ----------
    loga: array_like
        Input array of logged values.
    axis: None or int or tuple of ints, optional
        Axis or axes over which the mean is taken. By default axis is None,
        and the mean is taken over all elements.

    Returns
    -------
    log_mean: ndarray
        The logged average value; the averaged axes are removed from the
        shape of `loga`.

    """
    loga = np.asarray(loga)
    if axis is None:
        # Averaging over everything: N is simply the total element count.
        n_terms = loga.size
    else:
        # np.atleast_1d lets a bare int be handled like a 1-tuple of axes.
        n_terms = np.prod([loga.shape[i] for i in np.atleast_1d(axis)])
    return logsumexp(loga, axis=axis) - np.log(n_terms)
def max_slice(D, axis):
    """ Slice `D` through the location of its global maximum

    The position of the largest element of `D` is located; every axis listed
    in `axis` is pinned to that position while all remaining axes are kept
    in full. Length-one dimensions are squeezed out of the result.

    Parameters
    ----------
    D: ndarray
        The N-dimensional data.
    axis: int or tuple of ints
        Axis or axes to pin at the position of the maximum.

    Returns
    -------
    res: ndarray
        The squeezed slice of `D`.

    """
    peak = np.unravel_index(D.argmax(), D.shape)
    # Start from "take everything" along each axis, then pin the chosen ones.
    selectors = [np.arange(size) for size in D.shape]
    for pinned in np.atleast_1d(axis):
        selectors[pinned] = [peak[pinned]]
    return np.squeeze(D[np.ix_(*selectors)])
def idx_array_slice(D, axis, slice_idx):
    """ Slice `D` at caller-supplied indices along the given axes

    Every axis listed in `axis` is pinned to the index `slice_idx[axis]`
    while all remaining axes are kept in full. Length-one dimensions are
    squeezed out of the result.

    Parameters
    ----------
    D: ndarray
        The N-dimensional data.
    axis: int or tuple of ints
        Axis or axes to pin.
    slice_idx: sequence of ints
        Indexed by axis number; entry `k` is the index used when axis `k`
        is pinned.

    Returns
    -------
    res: ndarray
        The squeezed slice of `D`.

    """
    # Start from "take everything" along each axis, then pin the chosen ones.
    selectors = [np.arange(size) for size in D.shape]
    for pinned in np.atleast_1d(axis):
        selectors[pinned] = [slice_idx[pinned]]
    return np.squeeze(D[np.ix_(*selectors)])
def _share_axis(ax, anchor, which):
    """ Tie the `which` ('x' or 'y') axis of `ax` to that of `anchor` """
    grouper = getattr(ax, "get_shared_{}_axes".format(which))()
    if hasattr(grouper, "join"):
        # Older Matplotlib hands back a mutable Grouper.
        grouper.join(anchor, ax)
    elif ax is not anchor:
        # Newer Matplotlib returns a read-only view without `join`; the
        # supported route is Axes.sharex / Axes.sharey. An axes is never
        # shared with itself.
        getattr(ax, "share" + which)(anchor)


def gridcorner(D, xyz, labels=None, projection='max_slice', max_n_ticks=4,
               factor=3, whspace=0.05, showDvals=True, lines=None, **kwargs):
    """ Generate a grid corner plot

    Parameters
    ----------
    D: array_like
        N-dimensional data to plot, `D.shape` should be `(n1, n2,..., nN)`,
        where `ni` is the number of grid points along dimension `i`.
    xyz: list
        List of 1-dimensional arrays of coordinates. `xyz[i]` should have
        length `ni` (see help for `D`).
    labels: list
        N+1 length list of labels; the first N correspond to the coordinates
        labels, the final label is for the dependent (D) variable.
    projection: str or func
        If a string, one of `{"log_mean", "max_slice"}` to use inbuilt
        functions to calculate either the logged mean or maximum slice
        projection. Else a function to use for projection, must take an
        `axis` argument. Default is `gridcorner.max_slice()`, to project out
        a slice along the maximum.
    max_n_ticks: int
        Number of ticks for x and y axis of the `pcolormesh` plots.
    factor: float
        Controls the size of one window.
    whspace: float
        Passed as both `wspace` and `hspace` to `fig.subplots_adjust`.
    showDvals: bool
        If true (default) show the D values on the right-hand-side of the
        1D plots and add a label.
    lines: array_like
        N-dimensional list of values to delineate.
    **kwargs:
        Forwarded unchanged to the projection function.

    Returns
    -------
    fig, axes:
        The figure and NxN array of axes

    Raises
    ------
    ValueError
        If `projection` is a string other than "log_mean" or "max_slice".

    """
    D = np.asarray(D)
    ndim = D.ndim
    lbdim = 0.4 * factor  # size of left/bottom margin
    trdim = 0.2 * factor  # size of top/right margin
    plotdim = factor * ndim + factor * (ndim - 1.) * whspace
    dim = lbdim + plotdim + trdim

    if isinstance(projection, str):
        builtin = {'log_mean': log_mean, 'max_slice': max_slice}
        if projection not in builtin:
            raise ValueError("Projection {} not understood".format(projection))
        projection = builtin[projection]

    # squeeze=False keeps `axes` two-dimensional even when ndim == 1, so the
    # `axes[i, j]` indexing below works for 1-D data as well.
    fig, axes = plt.subplots(ndim, ndim, figsize=(dim, dim), squeeze=False)

    # Format the figure.
    lb = lbdim / dim
    tr = (lbdim + plotdim) / dim
    fig.subplots_adjust(left=lb, bottom=lb, right=0.98*tr, top=tr,
                        wspace=whspace, hspace=whspace)

    for i in range(ndim):
        # Diagonal panel: 1D projection onto coordinate i.
        projection_1D(
            axes[i, i], xyz[i], D, i, projection=projection,
            showDvals=showDvals, lines=lines, **kwargs)
        for j in range(ndim):
            ax = axes[i, j]

            if j > i:
                # Upper triangle is left blank.
                ax.set_frame_on(False)
                ax.set_xticks([])
                ax.set_yticks([])
                continue

            # Each column shares its x-axis with the bottom panel of that
            # column; only the bottom row keeps its x tick labels.
            _share_axis(ax, axes[ndim - 1, j], 'x')
            if i < ndim - 1:
                # Hide labels on this axes only; with shared axes a
                # formatter change could leak to the other panels.
                ax.tick_params(labelbottom=False)
            if j < i:
                # Off-diagonal panels in a row share their y-axis; only the
                # first column keeps its y tick labels.
                _share_axis(ax, axes[i, i - 1], 'y')
                if j > 0:
                    ax.tick_params(labelleft=False)
            if j == i:
                continue

            ax.xaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="upper"))
            ax.yaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="upper"))

            projection_2D(ax, xyz[i], xyz[j], D, i, j, lines=lines,
                          projection=projection, **kwargs)

    # Explicit None/length test so array-valued `labels` is not ambiguous.
    if labels is not None and len(labels) > 0:
        for k in range(ndim):
            axes[-1, k].set_xlabel(labels[k])
            if k > 0:
                axes[k, 0].set_ylabel(labels[k])
            if showDvals:
                axes[k, k].set_ylabel(labels[-1])
    return fig, axes
def projection_2D(ax, x, y, D, xidx, yidx, projection, lines=None, **kwargs):
    """ Draw a 2D projection of `D` onto a pair of its dimensions

    All dimensions of `D` other than `xidx` and `yidx` are handed to
    `projection` to be collapsed, and the result is drawn with
    `ax.pcolormesh`, using the global minimum and maximum of `D` as the
    colour limits.

    Parameters
    ----------
    ax:
        The axes to draw on.
    x, y: array_like
        1-dimensional coordinates for dimensions `xidx` and `yidx`. Note `x`
        is drawn along the vertical direction and `y` along the horizontal.
    D: ndarray
        The N-dimensional data.
    xidx, yidx: int
        The two dimensions of `D` to keep.
    projection: func
        Called as `projection(D, axis=<tuple of collapsed axes>, **kwargs)`.
    lines: array_like, optional
        If given, `lines[xidx]` is marked with a horizontal line and
        `lines[yidx]` with a vertical line.

    Returns
    -------
    ax, pax:
        The axes and the object returned by `ax.pcolormesh`.

    """
    # A range object has no `remove` on Python 3, so build the tuple of
    # collapsed axes directly.
    collapsed = tuple(k for k in range(D.ndim) if k not in (xidx, yidx))
    D2D = projection(D, axis=collapsed, **kwargs)
    X, Y = np.meshgrid(x, y, indexing='ij')
    pax = ax.pcolormesh(Y, X, D2D.T, vmin=D.min(), vmax=D.max())
    # Explicit None/length test: `if lines:` raises for NumPy arrays.
    if lines is not None and len(lines) > 0:
        ax.axhline(lines[xidx], lw=0.5, color='w')
        ax.axvline(lines[yidx], lw=0.5, color='w')
    return ax, pax
def projection_1D(ax, x, D, xidx, projection, showDvals=True, lines=None,
                  **kwargs):
    """ Draw a 1D projection of `D` onto one of its dimensions

    All dimensions of `D` other than `xidx` are handed to `projection` to be
    collapsed, and the result is plotted against `x` as a black line.

    Parameters
    ----------
    ax:
        The axes to draw on.
    x: array_like
        1-dimensional coordinates for dimension `xidx`.
    D: ndarray
        The N-dimensional data.
    xidx: int
        The dimension of `D` to keep.
    projection: func
        Called as `projection(D, axis=<tuple of collapsed axes>, **kwargs)`.
    showDvals: bool
        If true (default) put the y ticks and label on the right-hand side,
        otherwise remove the y tick labels.
    lines: array_like, optional
        If given, `lines[xidx]` is marked with a vertical line.

    Returns
    -------
    ax:
        The axes that was drawn on.

    """
    # A range object has no `remove` on Python 3, so build the tuple of
    # collapsed axes directly.
    collapsed = tuple(k for k in range(D.ndim) if k != xidx)
    D1D = projection(D, axis=collapsed, **kwargs)
    ax.plot(x, D1D, color='k')
    if showDvals:
        ax.yaxis.tick_right()
        ax.yaxis.set_label_position("right")
    else:
        ax.yaxis.set_ticklabels([])
    # Explicit None/length test: `if lines:` raises for NumPy arrays.
    if lines is not None and len(lines) > 0:
        ax.axvline(lines[xidx], lw=0.5, color='C0')
    return ax