diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4e622e2fa8150966d4ff5bde7c2dd50dfdd92586 --- /dev/null +++ b/README.md @@ -0,0 +1,40 @@ +# Projection matrix plot + +A `corner` plot for an array of dependent values + +Given an `N` dimensional set of data (i.e. some function evaluated over a grid +of coordinates), plot all possible 1D and 2D projections in the style of a +[`corner` plot](http://corner.readthedocs.io/en/latest/pages/quickstart.html). + +# Example + +Generating some fake data and plotting: + +```python +import numpy as np +import projection_matrix.projection_matrix as pmp + +# Generate example data +x = np.linspace(0, 1, 50) +y = np.linspace(20, 30, 60) +z = np.linspace(-2, -1, 70) +x0, y0, z0 = 0.5, 22, -1.5 +X, Y, Z = np.meshgrid(x, y, z, indexing='ij') +sigmax = 0.1 +sigmay = 1 +sigmaz = 0.2 +D = (np.exp(-(X-x0)**2/sigmax**2) + + np.exp(-(Y-y0)**2/sigmay**2) + + np.exp(-(Z-z0)**2/sigmaz**2)) + +fig, axes = pmp( + D, xyz=[x, y, z], labels=['x', 'y', 'z', 'D'], projection=np.max) +fig.savefig('example') +``` + + +# Acknowledgements + +The code uses both the central idea and some specific code from +[corner.py](https://github.com/dfm/corner.py) + diff --git a/example.png b/example.png new file mode 100644 index 0000000000000000000000000000000000000000..68fb7df6cea4d138dd451a3a9ba226063fbdf184 Binary files /dev/null and b/example.png differ diff --git a/example.py b/example.py new file mode 100644 index 0000000000000000000000000000000000000000..4ed9dcd3958a3977e34de550a4bc81a8078425b7 --- /dev/null +++ b/example.py @@ -0,0 +1,20 @@ +import numpy as np +import projection_matrix.projection_matrix as pmp + +# Generate example data +x = np.linspace(0, 1, 50) +y = np.linspace(20, 30, 60) +z = np.linspace(-2, -1, 70) +x0, y0, z0 = 0.5, 22, -1.5 +X, Y, Z = np.meshgrid(x, y, z, indexing='ij') +sigmax = 0.1 +sigmay = 1 +sigmaz = 0.2 +D = (np.exp(-(X-x0)**2/sigmax**2) + + np.exp(-(Y-y0)**2/sigmay**2) + + np.exp(-(Z-z0)**2/sigmaz**2)) + +fig, axes = pmp( + D, xyz=[x, y, z], labels=['x', 'y', 'z', 'D'], projection=np.max) +fig.savefig('example') + diff --git a/projection_matrix/__init__.py b/projection_matrix/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3a8b37d876720bec77a2f3ec0c2c40801890ae41 --- /dev/null +++ b/projection_matrix/__init__.py @@ -0,0 +1 @@ +from .projection_matrix import projection_matrix diff --git a/projection_matrix/projection_matrix.py b/projection_matrix/projection_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..f40cdd54d250d4a85c773c4322c50c25ad53dd6b --- /dev/null +++ b/projection_matrix/projection_matrix.py @@ -0,0 +1,102 @@ +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.ticker import MaxNLocator + + +def projection_matrix(D, xyz, labels=None, projection=np.max, max_n_ticks=4, + factor=3): + """ Generate a projection matrix plot + + Parameters + ---------- + D: array_like + N-dimensional data to plot, `D.shape` should be `(n1, n2,..., nn)`, + where `ni`, is the number of grid points along dimension `i`. + xyz: list + List of 1-dimensional arrays of coordinates. `xyz[i]` should have + length `ni` (see help for `D`). + labels: list + N+1 length list of labels; the first N correspond to the coordinates + labels, the final label is for the dependent variable. + projection: func + Function to use for projection, must take an `axis` argument. Default + is `np.max()`, to project out a slice along the maximum. + max_n_ticks: int + Number of ticks for x and y axis of the `pcolormesh` plots + factor: float + Controls the size of one window + + Returns + ------- + fig, axes: + The figure and NxN set of axes + + """ + ndim = D.ndim + lbdim = 0.5 * factor # size of left/bottom margin + trdim = 0.2 * factor # size of top/right margin + whspace = 0.05 # w/hspace size + plotdim = factor * ndim + factor * (ndim - 1.) * whspace + dim = lbdim + plotdim + trdim + + fig, axes = plt.subplots(ndim, ndim, figsize=(dim, dim)) + + # Format the figure. + lb = lbdim / dim + tr = (lbdim + plotdim) / dim + fig.subplots_adjust(left=lb, bottom=lb, right=tr, top=tr, + wspace=whspace, hspace=whspace) + for i in range(ndim): + projection_1D(axes[i, i], xyz[i], D, i, projection=projection) + for j in range(ndim): + ax = axes[i, j] + + if j > i: + ax.set_frame_on(False) + ax.set_xticks([]) + ax.set_yticks([]) + continue + + ax.get_shared_x_axes().join(axes[ndim-1, j], ax) + if i < ndim - 1: + ax.set_xticklabels([]) + if j < i: + ax.get_shared_y_axes().join(axes[i, i-1], ax) + if j > 0: + ax.set_yticklabels([]) + if j == i: + continue + + ax.xaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="upper")) + ax.yaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="upper")) + + ax, pax = projection_2D(ax, xyz[i], xyz[j], D, i, j, + projection=projection) + + if labels: + for i in range(ndim): + axes[-1, i].set_xlabel(labels[i]) + if i > 0: + axes[i, 0].set_ylabel(labels[i]) + axes[i, i].set_ylabel(labels[-1]) + return fig, axes + + +def projection_2D(ax, x, y, D, xidx, yidx, projection): + flat_idxs = range(D.ndim) + flat_idxs.remove(xidx) + flat_idxs.remove(yidx) + D2D = projection(D, axis=tuple(flat_idxs)) + X, Y = np.meshgrid(x, y, indexing='ij') + pax = ax.pcolormesh(Y, X, D2D.T, vmin=D.min(), vmax=D.max()) + return ax, pax + + +def projection_1D(ax, x, D, xidx, projection): + flat_idxs = range(D.ndim) + flat_idxs.remove(xidx) + D1D = projection(D, axis=tuple(flat_idxs)) + ax.plot(x, D1D) + ax.yaxis.tick_right() + ax.yaxis.set_label_position("right") + return ax diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..6f6dab0b20853f71e46597cac6b4fff791540848 --- /dev/null +++ b/setup.py @@ -0,0 +1,10 @@ +#!/usr/bin/env python + +from distutils.core import setup + +setup(name='projection_matrix', + version='0.1', + author='Gregory Ashton', + author_email='gregory.ashton@ligo.org', + packages=['projection_matrix'], + )