diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..4e622e2fa8150966d4ff5bde7c2dd50dfdd92586
--- /dev/null
+++ b/README.md
@@ -0,0 +1,40 @@
+# Projection matrix plot
+
+A `corner` plot for an array of dependent values
+
+Given an `N` dimensional set of data (i.e. some function evaluated over a grid
+of coordinates), plot all possible 1D and 2D projections in the style of a
+[`corner` plot](http://corner.readthedocs.io/en/latest/pages/quickstart.html).
+
+# Example
+
+Generating some fake data and plotting:
+
+```python
+import numpy as np
+import projection_matrix.projection_matrix as pmp
+
+# Generate example data
+x = np.linspace(0, 1, 50)
+y = np.linspace(20, 30, 60)
+z = np.linspace(-2, -1, 70)
+x0, y0, z0 = 0.5, 22, -1.5
+X, Y, Z = np.meshgrid(x, y, z, indexing='ij')
+sigmax = 0.1
+sigmay = 1
+sigmaz = 0.2
+D = (np.exp(-(X-x0)**2/sigmax**2)
+     + np.exp(-(Y-y0)**2/sigmay**2)
+     + np.exp(-(Z-z0)**2/sigmaz**2))
+
+fig, axes = pmp(
+    D, xyz=[x, y, z], labels=['x', 'y', 'z', 'D'], projection=np.max)
+fig.savefig('example')
+```
+![Example plot](example.png)
+
+# Acknowledgements
+
+The code uses both the central idea and some specific code from
+[corner.py](https://github.com/dfm/corner.py)
+
diff --git a/example.png b/example.png
new file mode 100644
index 0000000000000000000000000000000000000000..68fb7df6cea4d138dd451a3a9ba226063fbdf184
Binary files /dev/null and b/example.png differ
diff --git a/example.py b/example.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ed9dcd3958a3977e34de550a4bc81a8078425b7
--- /dev/null
+++ b/example.py
@@ -0,0 +1,20 @@
+import numpy as np
+import projection_matrix.projection_matrix as pmp
+
+# Generate example data
+x = np.linspace(0, 1, 50)
+y = np.linspace(20, 30, 60)
+z = np.linspace(-2, -1, 70)
+x0, y0, z0 = 0.5, 22, -1.5
+X, Y, Z = np.meshgrid(x, y, z, indexing='ij')
+sigmax = 0.1
+sigmay = 1
+sigmaz = 0.2
+D = (np.exp(-(X-x0)**2/sigmax**2)
+     + np.exp(-(Y-y0)**2/sigmay**2)
+     + np.exp(-(Z-z0)**2/sigmaz**2))
+
+fig, axes = pmp(
+    D, xyz=[x, y, z], labels=['x', 'y', 'z', 'D'], projection=np.max)
+fig.savefig('example')
+
diff --git a/projection_matrix/__init__.py b/projection_matrix/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a8b37d876720bec77a2f3ec0c2c40801890ae41
--- /dev/null
+++ b/projection_matrix/__init__.py
@@ -0,0 +1 @@
+from .projection_matrix import projection_matrix
diff --git a/projection_matrix/projection_matrix.py b/projection_matrix/projection_matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..f40cdd54d250d4a85c773c4322c50c25ad53dd6b
--- /dev/null
+++ b/projection_matrix/projection_matrix.py
@@ -0,0 +1,102 @@
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.ticker import MaxNLocator
+
+
+def projection_matrix(D, xyz, labels=None, projection=np.max, max_n_ticks=4,
+                      factor=3):
+    """ Generate a projection matrix plot
+
+    Parameters
+    ----------
+    D: array_like
+        N-dimensional data to plot, `D.shape` should be  `(n1, n2,..., nn)`,
+        where `ni`, is the number of grid points along dimension `i`.
+    xyz: list
+        List of 1-dimensional arrays of coordinates. `xyz[i]` should have
+        length `ni` (see help for `D`).
+    labels: list
+        N+1 length list of labels; the first N correspond to the coordinates
+        labels, the final label is for the dependent variable.
+    projection: func
+        Function to use for projection, must take an `axis` argument. Default
+        is `np.max()`, to project out a slice along the maximum.
+    max_n_ticks: int
+        Number of ticks for x and y axis of the `pcolormesh` plots
+    factor: float
+        Controls the size of one window
+
+    Returns
+    -------
+    fig, axes:
+        The figure and NxN set of axes
+
+    """
+    ndim = D.ndim
+    lbdim = 0.5 * factor   # size of left/bottom margin
+    trdim = 0.2 * factor   # size of top/right margin
+    whspace = 0.05         # w/hspace size
+    plotdim = factor * ndim + factor * (ndim - 1.) * whspace
+    dim = lbdim + plotdim + trdim
+
+    fig, axes = plt.subplots(ndim, ndim, figsize=(dim, dim))
+
+    # Format the figure.
+    lb = lbdim / dim
+    tr = (lbdim + plotdim) / dim
+    fig.subplots_adjust(left=lb, bottom=lb, right=tr, top=tr,
+                        wspace=whspace, hspace=whspace)
+    for i in range(ndim):
+        projection_1D(axes[i, i], xyz[i], D, i, projection=projection)
+        for j in range(ndim):
+            ax = axes[i, j]
+
+            if j > i:
+                ax.set_frame_on(False)
+                ax.set_xticks([])
+                ax.set_yticks([])
+                continue
+
+            ax.get_shared_x_axes().join(axes[ndim-1, j], ax)
+            if i < ndim - 1:
+                ax.set_xticklabels([])
+            if j < i:
+                ax.get_shared_y_axes().join(axes[i, i-1], ax)
+                if j > 0:
+                    ax.set_yticklabels([])
+            if j == i:
+                continue
+
+            ax.xaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="upper"))
+            ax.yaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="upper"))
+
+            ax, pax = projection_2D(ax, xyz[i], xyz[j], D, i, j,
+                                    projection=projection)
+
+    if labels:
+        for i in range(ndim):
+            axes[-1, i].set_xlabel(labels[i])
+            if i > 0:
+                axes[i, 0].set_ylabel(labels[i])
+            axes[i, i].set_ylabel(labels[-1])
+    return fig, axes
+
+
+def projection_2D(ax, x, y, D, xidx, yidx, projection):
+    flat_idxs = range(D.ndim)
+    flat_idxs.remove(xidx)
+    flat_idxs.remove(yidx)
+    D2D = projection(D, axis=tuple(flat_idxs))
+    X, Y = np.meshgrid(x, y, indexing='ij')
+    pax = ax.pcolormesh(Y, X, D2D.T, vmin=D.min(), vmax=D.max())
+    return ax, pax
+
+
+def projection_1D(ax, x, D, xidx, projection):
+    flat_idxs = range(D.ndim)
+    flat_idxs.remove(xidx)
+    D1D = projection(D, axis=tuple(flat_idxs))
+    ax.plot(x, D1D)
+    ax.yaxis.tick_right()
+    ax.yaxis.set_label_position("right")
+    return ax
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f6dab0b20853f71e46597cac6b4fff791540848
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,10 @@
+#!/usr/bin/env python
+
+from distutils.core import setup
+
+setup(name='projection_matrix',
+      version='0.1',
+      author='Gregory Ashton',
+      author_email='gregory.ashton@ligo.org',
+      packages=['projection_matrix'],
+      )