Commit 0a0105b7 authored by Gregory Ashton's avatar Gregory Ashton

Initial commit adding basic idea

parent db4a2853
# Projection matrix plot
A `corner` plot for an array of dependent values
Given an `N` dimensional set of data (i.e. some function evaluated over a grid
of coordinates), plot all possible 1D and 2D projections in the style of a
[`corner` plot](http://corner.readthedocs.io/en/latest/pages/quickstart.html).
# Example
Generating some fake data and plotting:
```python
import numpy as np
import projection_matrix.projection_matrix as pmp
# Generate example data
x = np.linspace(0, 1, 50)
y = np.linspace(20, 30, 60)
z = np.linspace(-2, -1, 70)
x0, y0, z0 = 0.5, 22, -1.5
X, Y, Z = np.meshgrid(x, y, z, indexing='ij')
sigmax = 0.1
sigmay = 1
sigmaz = 0.2
D = (np.exp(-(X-x0)**2/sigmax**2)
+ np.exp(-(Y-y0)**2/sigmay**2)
+ np.exp(-(Z-z0)**2/sigmaz**2))
fig, axes = pmp(
D, xyz=[x, y, z], labels=['x', 'y', 'z', 'D'], projection=np.max)
fig.savefig('example')
```
![Example plot](example.png)
# Acknowledgements
The code uses both the central idea and some specific code from
[corner.py](https://github.com/dfm/corner.py)
example.png

65.1 KB

import numpy as np
import projection_matrix.projection_matrix as pmp
# Generate example data
x = np.linspace(0, 1, 50)
y = np.linspace(20, 30, 60)
z = np.linspace(-2, -1, 70)
x0, y0, z0 = 0.5, 22, -1.5
X, Y, Z = np.meshgrid(x, y, z, indexing='ij')
sigmax = 0.1
sigmay = 1
sigmaz = 0.2
D = (np.exp(-(X-x0)**2/sigmax**2)
+ np.exp(-(Y-y0)**2/sigmay**2)
+ np.exp(-(Z-z0)**2/sigmaz**2))
fig, axes = pmp(
D, xyz=[x, y, z], labels=['x', 'y', 'z', 'D'], projection=np.max)
fig.savefig('example')
from .projection_matrix import projection_matrix
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
def projection_matrix(D, xyz, labels=None, projection=np.max, max_n_ticks=4,
factor=3):
""" Generate a projection matrix plot
Parameters
----------
D: array_like
N-dimensional data to plot, `D.shape` should be `(n1, n2,..., nn)`,
where `ni`, is the number of grid points along dimension `i`.
xyz: list
List of 1-dimensional arrays of coordinates. `xyz[i]` should have
length `ni` (see help for `D`).
labels: list
N+1 length list of labels; the first N correspond to the coordinates
labels, the final label is for the dependent variable.
projection: func
Function to use for projection, must take an `axis` argument. Default
is `np.max()`, to project out a slice along the maximum.
max_n_ticks: int
Number of ticks for x and y axis of the `pcolormesh` plots
factor: float
Controls the size of one window
Returns
-------
fig, axes:
The figure and NxN set of axes
"""
ndim = D.ndim
lbdim = 0.5 * factor # size of left/bottom margin
trdim = 0.2 * factor # size of top/right margin
whspace = 0.05 # w/hspace size
plotdim = factor * ndim + factor * (ndim - 1.) * whspace
dim = lbdim + plotdim + trdim
fig, axes = plt.subplots(ndim, ndim, figsize=(dim, dim))
# Format the figure.
lb = lbdim / dim
tr = (lbdim + plotdim) / dim
fig.subplots_adjust(left=lb, bottom=lb, right=tr, top=tr,
wspace=whspace, hspace=whspace)
for i in range(ndim):
projection_1D(axes[i, i], xyz[i], D, i, projection=projection)
for j in range(ndim):
ax = axes[i, j]
if j > i:
ax.set_frame_on(False)
ax.set_xticks([])
ax.set_yticks([])
continue
ax.get_shared_x_axes().join(axes[ndim-1, j], ax)
if i < ndim - 1:
ax.set_xticklabels([])
if j < i:
ax.get_shared_y_axes().join(axes[i, i-1], ax)
if j > 0:
ax.set_yticklabels([])
if j == i:
continue
ax.xaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="upper"))
ax.yaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="upper"))
ax, pax = projection_2D(ax, xyz[i], xyz[j], D, i, j,
projection=projection)
if labels:
for i in range(ndim):
axes[-1, i].set_xlabel(labels[i])
if i > 0:
axes[i, 0].set_ylabel(labels[i])
axes[i, i].set_ylabel(labels[-1])
return fig, axes
def projection_2D(ax, x, y, D, xidx, yidx, projection):
flat_idxs = range(D.ndim)
flat_idxs.remove(xidx)
flat_idxs.remove(yidx)
D2D = projection(D, axis=tuple(flat_idxs))
X, Y = np.meshgrid(x, y, indexing='ij')
pax = ax.pcolormesh(Y, X, D2D.T, vmin=D.min(), vmax=D.max())
return ax, pax
def projection_1D(ax, x, D, xidx, projection):
flat_idxs = range(D.ndim)
flat_idxs.remove(xidx)
D1D = projection(D, axis=tuple(flat_idxs))
ax.plot(x, D1D)
ax.yaxis.tick_right()
ax.yaxis.set_label_position("right")
return ax
#!/usr/bin/env python
from distutils.core import setup
setup(name='projection_matrix',
version='0.1',
author='Gregory Ashton',
author_email='gregory.ashton@ligo.org',
packages=['projection_matrix'],
)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment