From bb07e6084226aaf7d497d56e6ac9d26f5f65fbec Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Mon, 30 Oct 2017 14:04:41 +0100
Subject: [PATCH] Adds new projection methods and improve docs

---
 projection_matrix/__init__.py          |  2 +-
 projection_matrix/projection_matrix.py | 69 +++++++++++++++++++++-----
 2 files changed, 58 insertions(+), 13 deletions(-)

diff --git a/projection_matrix/__init__.py b/projection_matrix/__init__.py
index 5b1f2f8..3a8b37d 100644
--- a/projection_matrix/__init__.py
+++ b/projection_matrix/__init__.py
@@ -1 +1 @@
-from .projection_matrix import projection_matrix, slice_max
+from .projection_matrix import projection_matrix
diff --git a/projection_matrix/projection_matrix.py b/projection_matrix/projection_matrix.py
index 769618c..b3c808a 100644
--- a/projection_matrix/projection_matrix.py
+++ b/projection_matrix/projection_matrix.py
@@ -1,9 +1,34 @@
 import numpy as np
 import matplotlib.pyplot as plt
 from matplotlib.ticker import MaxNLocator
+from scipy.misc import logsumexp
 
 
-def slice_max(D, axis):
+def log_mean(loga, axis):
+    """ Calculate the log(<a>) mean
+
+    Given `N` logged value `log`, calculate the log_mean
+    `log(<loga>)=log(sum(np.exp(loga))) - log(N)`. Useful for marginalizing
+    over logged likelihoods for example.
+
+    Parameters
+    ----------
+    loga: array_like
+        Input_array.
+    axies: None or int or type of ints, optional
+        Axis or axes over which the sum is taken. By default axis is None, and
+        all elements are summed.
+    Returns
+    -------
+    log_mean: ndarry
+        The logged average value (shape loga.shape)
+    """
+    loga = np.array(loga)
+    N = np.prod([loga.shape[i] for i in axis])
+    return logsumexp(loga, axis) - np.log(N)
+
+
+def max_slice(D, axis):
     """ Return the slice along the given axis """
     idxs = [range(D.shape[j]) for j in range(D.ndim)]
     max_idx = list(np.unravel_index(D.argmax(), D.shape))
@@ -13,8 +38,17 @@ def slice_max(D, axis):
     return res
 
 
-def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
-                      factor=3):
+def idx_array_slice(D, axis, slice_idx):
+    """ Return the slice along the given axis """
+    idxs = [range(D.shape[j]) for j in range(D.ndim)]
+    for k in np.atleast_1d(axis):
+        idxs[k] = [slice_idx[k]]
+    res = np.squeeze(D[np.ix_(*tuple(idxs))])
+    return res
+
+
+def projection_matrix(D, xyz, labels=None, projection='max_slice',
+                      max_n_ticks=4, factor=3, **kwargs):
     """ Generate a projection matrix plot
 
     Parameters
@@ -28,9 +62,11 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
     labels: list
         N+1 length list of labels; the first N correspond to the coordinates
         labels, the final label is for the dependent variable.
-    projection: func
-        Function to use for projection, must take an `axis` argument. Default
-        is `projection_matrix.slice_max()`, to project out a slice along the
+    projection: str or func
+        If a string, one of `{"log_mean", "max_slice"} to use inbuilt functions
+        to calculate either the logged mean or maximum slice projection. Else
+        a function to use for projection, must take an `axis` argument. Default
+        is `projection_matrix.max_slice()`, to project out a slice along the
         maximum.
     max_n_ticks: int
         Number of ticks for x and y axis of the `pcolormesh` plots
@@ -50,6 +86,14 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
     plotdim = factor * ndim + factor * (ndim - 1.) * whspace
     dim = lbdim + plotdim + trdim
 
+    if type(projection) == str:
+        if projection in ['log_mean']:
+            projection = log_mean
+        elif projection in ['max_slice']:
+            projection = max_slice
+        else:
+            raise ValueError("Projection {} not understood".format(projection))
+
     fig, axes = plt.subplots(ndim, ndim, figsize=(dim, dim))
 
     # Format the figure.
@@ -58,7 +102,8 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
     fig.subplots_adjust(left=lb, bottom=lb, right=tr, top=tr,
                         wspace=whspace, hspace=whspace)
     for i in range(ndim):
-        projection_1D(axes[i, i], xyz[i], D, i, projection=projection)
+        projection_1D(
+            axes[i, i], xyz[i], D, i, projection=projection, **kwargs)
         for j in range(ndim):
             ax = axes[i, j]
 
@@ -82,7 +127,7 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
             ax.yaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="upper"))
 
             ax, pax = projection_2D(ax, xyz[i], xyz[j], D, i, j,
-                                    projection=projection)
+                                    projection=projection, **kwargs)
 
     if labels:
         for i in range(ndim):
@@ -93,20 +138,20 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
     return fig, axes
 
 
-def projection_2D(ax, x, y, D, xidx, yidx, projection):
+def projection_2D(ax, x, y, D, xidx, yidx, projection, **kwargs):
     flat_idxs = range(D.ndim)
     flat_idxs.remove(xidx)
     flat_idxs.remove(yidx)
-    D2D = projection(D, axis=tuple(flat_idxs))
+    D2D = projection(D, axis=tuple(flat_idxs), **kwargs)
     X, Y = np.meshgrid(x, y, indexing='ij')
     pax = ax.pcolormesh(Y, X, D2D.T, vmin=D.min(), vmax=D.max())
     return ax, pax
 
 
-def projection_1D(ax, x, D, xidx, projection):
+def projection_1D(ax, x, D, xidx, projection, **kwargs):
     flat_idxs = range(D.ndim)
     flat_idxs.remove(xidx)
-    D1D = projection(D, axis=tuple(flat_idxs))
+    D1D = projection(D, axis=tuple(flat_idxs), **kwargs)
     ax.plot(x, D1D)
     ax.yaxis.tick_right()
     ax.yaxis.set_label_position("right")
-- 
GitLab