Commit bb07e608 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Adds new projection methods and improve docs

parent 78457b59
from .projection_matrix import projection_matrix, slice_max from .projection_matrix import projection_matrix
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator from matplotlib.ticker import MaxNLocator
from scipy.misc import logsumexp
def slice_max(D, axis): def log_mean(loga, axis):
""" Calculate the log(<a>) mean
Given `N` logged value `log`, calculate the log_mean
`log(<loga>)=log(sum(np.exp(loga))) - log(N)`. Useful for marginalizing
over logged likelihoods for example.
Parameters
----------
loga: array_like
Input_array.
axies: None or int or type of ints, optional
Axis or axes over which the sum is taken. By default axis is None, and
all elements are summed.
Returns
-------
log_mean: ndarry
The logged average value (shape loga.shape)
"""
loga = np.array(loga)
N = np.prod([loga.shape[i] for i in axis])
return logsumexp(loga, axis) - np.log(N)
def max_slice(D, axis):
""" Return the slice along the given axis """ """ Return the slice along the given axis """
idxs = [range(D.shape[j]) for j in range(D.ndim)] idxs = [range(D.shape[j]) for j in range(D.ndim)]
max_idx = list(np.unravel_index(D.argmax(), D.shape)) max_idx = list(np.unravel_index(D.argmax(), D.shape))
...@@ -13,8 +38,17 @@ def slice_max(D, axis): ...@@ -13,8 +38,17 @@ def slice_max(D, axis):
return res return res
def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4, def idx_array_slice(D, axis, slice_idx):
factor=3): """ Return the slice along the given axis """
idxs = [range(D.shape[j]) for j in range(D.ndim)]
for k in np.atleast_1d(axis):
idxs[k] = [slice_idx[k]]
res = np.squeeze(D[np.ix_(*tuple(idxs))])
return res
def projection_matrix(D, xyz, labels=None, projection='max_slice',
max_n_ticks=4, factor=3, **kwargs):
""" Generate a projection matrix plot """ Generate a projection matrix plot
Parameters Parameters
...@@ -28,9 +62,11 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4, ...@@ -28,9 +62,11 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
labels: list labels: list
N+1 length list of labels; the first N correspond to the coordinates N+1 length list of labels; the first N correspond to the coordinates
labels, the final label is for the dependent variable. labels, the final label is for the dependent variable.
projection: func projection: str or func
Function to use for projection, must take an `axis` argument. Default If a string, one of `{"log_mean", "max_slice"} to use inbuilt functions
is `projection_matrix.slice_max()`, to project out a slice along the to calculate either the logged mean or maximum slice projection. Else
a function to use for projection, must take an `axis` argument. Default
is `projection_matrix.max_slice()`, to project out a slice along the
maximum. maximum.
max_n_ticks: int max_n_ticks: int
Number of ticks for x and y axis of the `pcolormesh` plots Number of ticks for x and y axis of the `pcolormesh` plots
...@@ -50,6 +86,14 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4, ...@@ -50,6 +86,14 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
plotdim = factor * ndim + factor * (ndim - 1.) * whspace plotdim = factor * ndim + factor * (ndim - 1.) * whspace
dim = lbdim + plotdim + trdim dim = lbdim + plotdim + trdim
if type(projection) == str:
if projection in ['log_mean']:
projection = log_mean
elif projection in ['max_slice']:
projection = max_slice
else:
raise ValueError("Projection {} not understood".format(projection))
fig, axes = plt.subplots(ndim, ndim, figsize=(dim, dim)) fig, axes = plt.subplots(ndim, ndim, figsize=(dim, dim))
# Format the figure. # Format the figure.
...@@ -58,7 +102,8 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4, ...@@ -58,7 +102,8 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
fig.subplots_adjust(left=lb, bottom=lb, right=tr, top=tr, fig.subplots_adjust(left=lb, bottom=lb, right=tr, top=tr,
wspace=whspace, hspace=whspace) wspace=whspace, hspace=whspace)
for i in range(ndim): for i in range(ndim):
projection_1D(axes[i, i], xyz[i], D, i, projection=projection) projection_1D(
axes[i, i], xyz[i], D, i, projection=projection, **kwargs)
for j in range(ndim): for j in range(ndim):
ax = axes[i, j] ax = axes[i, j]
...@@ -82,7 +127,7 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4, ...@@ -82,7 +127,7 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
ax.yaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="upper")) ax.yaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="upper"))
ax, pax = projection_2D(ax, xyz[i], xyz[j], D, i, j, ax, pax = projection_2D(ax, xyz[i], xyz[j], D, i, j,
projection=projection) projection=projection, **kwargs)
if labels: if labels:
for i in range(ndim): for i in range(ndim):
...@@ -93,20 +138,20 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4, ...@@ -93,20 +138,20 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
return fig, axes return fig, axes
def projection_2D(ax, x, y, D, xidx, yidx, projection): def projection_2D(ax, x, y, D, xidx, yidx, projection, **kwargs):
flat_idxs = range(D.ndim) flat_idxs = range(D.ndim)
flat_idxs.remove(xidx) flat_idxs.remove(xidx)
flat_idxs.remove(yidx) flat_idxs.remove(yidx)
D2D = projection(D, axis=tuple(flat_idxs)) D2D = projection(D, axis=tuple(flat_idxs), **kwargs)
X, Y = np.meshgrid(x, y, indexing='ij') X, Y = np.meshgrid(x, y, indexing='ij')
pax = ax.pcolormesh(Y, X, D2D.T, vmin=D.min(), vmax=D.max()) pax = ax.pcolormesh(Y, X, D2D.T, vmin=D.min(), vmax=D.max())
return ax, pax return ax, pax
def projection_1D(ax, x, D, xidx, projection): def projection_1D(ax, x, D, xidx, projection, **kwargs):
flat_idxs = range(D.ndim) flat_idxs = range(D.ndim)
flat_idxs.remove(xidx) flat_idxs.remove(xidx)
D1D = projection(D, axis=tuple(flat_idxs)) D1D = projection(D, axis=tuple(flat_idxs), **kwargs)
ax.plot(x, D1D) ax.plot(x, D1D)
ax.yaxis.tick_right() ax.yaxis.tick_right()
ax.yaxis.set_label_position("right") ax.yaxis.set_label_position("right")
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment