From bb07e6084226aaf7d497d56e6ac9d26f5f65fbec Mon Sep 17 00:00:00 2001 From: Gregory Ashton <gregory.ashton@ligo.org> Date: Mon, 30 Oct 2017 14:04:41 +0100 Subject: [PATCH] Adds new projection methods and improve docs --- projection_matrix/__init__.py | 2 +- projection_matrix/projection_matrix.py | 69 +++++++++++++++++++++----- 2 files changed, 58 insertions(+), 13 deletions(-) diff --git a/projection_matrix/__init__.py b/projection_matrix/__init__.py index 5b1f2f8..3a8b37d 100644 --- a/projection_matrix/__init__.py +++ b/projection_matrix/__init__.py @@ -1 +1 @@ -from .projection_matrix import projection_matrix, slice_max +from .projection_matrix import projection_matrix diff --git a/projection_matrix/projection_matrix.py b/projection_matrix/projection_matrix.py index 769618c..b3c808a 100644 --- a/projection_matrix/projection_matrix.py +++ b/projection_matrix/projection_matrix.py @@ -1,9 +1,34 @@ import numpy as np import matplotlib.pyplot as plt from matplotlib.ticker import MaxNLocator +from scipy.misc import logsumexp -def slice_max(D, axis): +def log_mean(loga, axis): + """ Calculate the log(<a>) mean + + Given `N` logged value `log`, calculate the log_mean + `log(<loga>)=log(sum(np.exp(loga))) - log(N)`. Useful for marginalizing + over logged likelihoods for example. + + Parameters + ---------- + loga: array_like + Input_array. + axies: None or int or type of ints, optional + Axis or axes over which the sum is taken. By default axis is None, and + all elements are summed. + Returns + ------- + log_mean: ndarry + The logged average value (shape loga.shape) + """ + loga = np.array(loga) + N = np.prod([loga.shape[i] for i in axis]) + return logsumexp(loga, axis) - np.log(N) + + +def max_slice(D, axis): """ Return the slice along the given axis """ idxs = [range(D.shape[j]) for j in range(D.ndim)] max_idx = list(np.unravel_index(D.argmax(), D.shape)) @@ -13,8 +38,17 @@ def slice_max(D, axis): return res -def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4, - factor=3): +def idx_array_slice(D, axis, slice_idx): + """ Return the slice along the given axis """ + idxs = [range(D.shape[j]) for j in range(D.ndim)] + for k in np.atleast_1d(axis): + idxs[k] = [slice_idx[k]] + res = np.squeeze(D[np.ix_(*tuple(idxs))]) + return res + + +def projection_matrix(D, xyz, labels=None, projection='max_slice', + max_n_ticks=4, factor=3, **kwargs): """ Generate a projection matrix plot Parameters @@ -28,9 +62,11 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4, labels: list N+1 length list of labels; the first N correspond to the coordinates labels, the final label is for the dependent variable. - projection: func - Function to use for projection, must take an `axis` argument. Default - is `projection_matrix.slice_max()`, to project out a slice along the + projection: str or func + If a string, one of `{"log_mean", "max_slice"} to use inbuilt functions + to calculate either the logged mean or maximum slice projection. Else + a function to use for projection, must take an `axis` argument. Default + is `projection_matrix.max_slice()`, to project out a slice along the maximum. max_n_ticks: int Number of ticks for x and y axis of the `pcolormesh` plots @@ -50,6 +86,14 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4, plotdim = factor * ndim + factor * (ndim - 1.) * whspace dim = lbdim + plotdim + trdim + if type(projection) == str: + if projection in ['log_mean']: + projection = log_mean + elif projection in ['max_slice']: + projection = max_slice + else: + raise ValueError("Projection {} not understood".format(projection)) + fig, axes = plt.subplots(ndim, ndim, figsize=(dim, dim)) # Format the figure. @@ -58,7 +102,8 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4, fig.subplots_adjust(left=lb, bottom=lb, right=tr, top=tr, wspace=whspace, hspace=whspace) for i in range(ndim): - projection_1D(axes[i, i], xyz[i], D, i, projection=projection) + projection_1D( + axes[i, i], xyz[i], D, i, projection=projection, **kwargs) for j in range(ndim): ax = axes[i, j] @@ -82,7 +127,7 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4, ax.yaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="upper")) ax, pax = projection_2D(ax, xyz[i], xyz[j], D, i, j, - projection=projection) + projection=projection, **kwargs) if labels: for i in range(ndim): @@ -93,20 +138,20 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4, return fig, axes -def projection_2D(ax, x, y, D, xidx, yidx, projection): +def projection_2D(ax, x, y, D, xidx, yidx, projection, **kwargs): flat_idxs = range(D.ndim) flat_idxs.remove(xidx) flat_idxs.remove(yidx) - D2D = projection(D, axis=tuple(flat_idxs)) + D2D = projection(D, axis=tuple(flat_idxs), **kwargs) X, Y = np.meshgrid(x, y, indexing='ij') pax = ax.pcolormesh(Y, X, D2D.T, vmin=D.min(), vmax=D.max()) return ax, pax -def projection_1D(ax, x, D, xidx, projection): +def projection_1D(ax, x, D, xidx, projection, **kwargs): flat_idxs = range(D.ndim) flat_idxs.remove(xidx) - D1D = projection(D, axis=tuple(flat_idxs)) + D1D = projection(D, axis=tuple(flat_idxs), **kwargs) ax.plot(x, D1D) ax.yaxis.tick_right() ax.yaxis.set_label_position("right") -- GitLab