Skip to content
Snippets Groups Projects
Commit bb07e608 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Adds new projection methods and improve docs

parent 78457b59
No related branches found
No related tags found
No related merge requests found
from .projection_matrix import projection_matrix, slice_max
from .projection_matrix import projection_matrix
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from scipy.misc import logsumexp
def slice_max(D, axis):
def log_mean(loga, axis):
""" Calculate the log(<a>) mean
Given `N` logged value `log`, calculate the log_mean
`log(<loga>)=log(sum(np.exp(loga))) - log(N)`. Useful for marginalizing
over logged likelihoods for example.
Parameters
----------
loga: array_like
Input_array.
axies: None or int or type of ints, optional
Axis or axes over which the sum is taken. By default axis is None, and
all elements are summed.
Returns
-------
log_mean: ndarry
The logged average value (shape loga.shape)
"""
loga = np.array(loga)
N = np.prod([loga.shape[i] for i in axis])
return logsumexp(loga, axis) - np.log(N)
def max_slice(D, axis):
""" Return the slice along the given axis """
idxs = [range(D.shape[j]) for j in range(D.ndim)]
max_idx = list(np.unravel_index(D.argmax(), D.shape))
......@@ -13,8 +38,17 @@ def slice_max(D, axis):
return res
def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
factor=3):
def idx_array_slice(D, axis, slice_idx):
""" Return the slice along the given axis """
idxs = [range(D.shape[j]) for j in range(D.ndim)]
for k in np.atleast_1d(axis):
idxs[k] = [slice_idx[k]]
res = np.squeeze(D[np.ix_(*tuple(idxs))])
return res
def projection_matrix(D, xyz, labels=None, projection='max_slice',
max_n_ticks=4, factor=3, **kwargs):
""" Generate a projection matrix plot
Parameters
......@@ -28,9 +62,11 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
labels: list
N+1 length list of labels; the first N correspond to the coordinates
labels, the final label is for the dependent variable.
projection: func
Function to use for projection, must take an `axis` argument. Default
is `projection_matrix.slice_max()`, to project out a slice along the
projection: str or func
If a string, one of `{"log_mean", "max_slice"} to use inbuilt functions
to calculate either the logged mean or maximum slice projection. Else
a function to use for projection, must take an `axis` argument. Default
is `projection_matrix.max_slice()`, to project out a slice along the
maximum.
max_n_ticks: int
Number of ticks for x and y axis of the `pcolormesh` plots
......@@ -50,6 +86,14 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
plotdim = factor * ndim + factor * (ndim - 1.) * whspace
dim = lbdim + plotdim + trdim
if type(projection) == str:
if projection in ['log_mean']:
projection = log_mean
elif projection in ['max_slice']:
projection = max_slice
else:
raise ValueError("Projection {} not understood".format(projection))
fig, axes = plt.subplots(ndim, ndim, figsize=(dim, dim))
# Format the figure.
......@@ -58,7 +102,8 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
fig.subplots_adjust(left=lb, bottom=lb, right=tr, top=tr,
wspace=whspace, hspace=whspace)
for i in range(ndim):
projection_1D(axes[i, i], xyz[i], D, i, projection=projection)
projection_1D(
axes[i, i], xyz[i], D, i, projection=projection, **kwargs)
for j in range(ndim):
ax = axes[i, j]
......@@ -82,7 +127,7 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
ax.yaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="upper"))
ax, pax = projection_2D(ax, xyz[i], xyz[j], D, i, j,
projection=projection)
projection=projection, **kwargs)
if labels:
for i in range(ndim):
......@@ -93,20 +138,20 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
return fig, axes
def projection_2D(ax, x, y, D, xidx, yidx, projection):
def projection_2D(ax, x, y, D, xidx, yidx, projection, **kwargs):
flat_idxs = range(D.ndim)
flat_idxs.remove(xidx)
flat_idxs.remove(yidx)
D2D = projection(D, axis=tuple(flat_idxs))
D2D = projection(D, axis=tuple(flat_idxs), **kwargs)
X, Y = np.meshgrid(x, y, indexing='ij')
pax = ax.pcolormesh(Y, X, D2D.T, vmin=D.min(), vmax=D.max())
return ax, pax
def projection_1D(ax, x, D, xidx, projection):
def projection_1D(ax, x, D, xidx, projection, **kwargs):
flat_idxs = range(D.ndim)
flat_idxs.remove(xidx)
D1D = projection(D, axis=tuple(flat_idxs))
D1D = projection(D, axis=tuple(flat_idxs), **kwargs)
ax.plot(x, D1D)
ax.yaxis.tick_right()
ax.yaxis.set_label_position("right")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment