Skip to content
Snippets Groups Projects
Commit 0a0105b7 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Initial commit adding basic idea

parent db4a2853
No related branches found
No related tags found
No related merge requests found
# Projection matrix plot
A `corner` plot for an array of dependent values
Given an `N` dimensional set of data (i.e. some function evaluated over a grid
of coordinates), plot all possible 1D and 2D projections in the style of a
[`corner` plot](http://corner.readthedocs.io/en/latest/pages/quickstart.html).
# Example
Generating some fake data and plotting:
```python
import numpy as np
import projection_matrix.projection_matrix as pmp
# Generate example data
x = np.linspace(0, 1, 50)
y = np.linspace(20, 30, 60)
z = np.linspace(-2, -1, 70)
x0, y0, z0 = 0.5, 22, -1.5
X, Y, Z = np.meshgrid(x, y, z, indexing='ij')
sigmax = 0.1
sigmay = 1
sigmaz = 0.2
D = (np.exp(-(X-x0)**2/sigmax**2)
+ np.exp(-(Y-y0)**2/sigmay**2)
+ np.exp(-(Z-z0)**2/sigmaz**2))
fig, axes = pmp(
D, xyz=[x, y, z], labels=['x', 'y', 'z', 'D'], projection=np.max)
fig.savefig('example')
```
![Example plot](example.png)
# Acknowledgements
The code uses both the central idea and some specific code from
[corner.py](https://github.com/dfm/corner.py)
example.png

65.1 KiB

import numpy as np
import projection_matrix.projection_matrix as pmp
# Generate example data
x = np.linspace(0, 1, 50)
y = np.linspace(20, 30, 60)
z = np.linspace(-2, -1, 70)
x0, y0, z0 = 0.5, 22, -1.5
X, Y, Z = np.meshgrid(x, y, z, indexing='ij')
sigmax = 0.1
sigmay = 1
sigmaz = 0.2
D = (np.exp(-(X-x0)**2/sigmax**2)
+ np.exp(-(Y-y0)**2/sigmay**2)
+ np.exp(-(Z-z0)**2/sigmaz**2))
fig, axes = pmp(
D, xyz=[x, y, z], labels=['x', 'y', 'z', 'D'], projection=np.max)
fig.savefig('example')
from .projection_matrix import projection_matrix
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
def projection_matrix(D, xyz, labels=None, projection=np.max, max_n_ticks=4,
factor=3):
""" Generate a projection matrix plot
Parameters
----------
D: array_like
N-dimensional data to plot, `D.shape` should be `(n1, n2,..., nn)`,
where `ni`, is the number of grid points along dimension `i`.
xyz: list
List of 1-dimensional arrays of coordinates. `xyz[i]` should have
length `ni` (see help for `D`).
labels: list
N+1 length list of labels; the first N correspond to the coordinates
labels, the final label is for the dependent variable.
projection: func
Function to use for projection, must take an `axis` argument. Default
is `np.max()`, to project out a slice along the maximum.
max_n_ticks: int
Number of ticks for x and y axis of the `pcolormesh` plots
factor: float
Controls the size of one window
Returns
-------
fig, axes:
The figure and NxN set of axes
"""
ndim = D.ndim
lbdim = 0.5 * factor # size of left/bottom margin
trdim = 0.2 * factor # size of top/right margin
whspace = 0.05 # w/hspace size
plotdim = factor * ndim + factor * (ndim - 1.) * whspace
dim = lbdim + plotdim + trdim
fig, axes = plt.subplots(ndim, ndim, figsize=(dim, dim))
# Format the figure.
lb = lbdim / dim
tr = (lbdim + plotdim) / dim
fig.subplots_adjust(left=lb, bottom=lb, right=tr, top=tr,
wspace=whspace, hspace=whspace)
for i in range(ndim):
projection_1D(axes[i, i], xyz[i], D, i, projection=projection)
for j in range(ndim):
ax = axes[i, j]
if j > i:
ax.set_frame_on(False)
ax.set_xticks([])
ax.set_yticks([])
continue
ax.get_shared_x_axes().join(axes[ndim-1, j], ax)
if i < ndim - 1:
ax.set_xticklabels([])
if j < i:
ax.get_shared_y_axes().join(axes[i, i-1], ax)
if j > 0:
ax.set_yticklabels([])
if j == i:
continue
ax.xaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="upper"))
ax.yaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="upper"))
ax, pax = projection_2D(ax, xyz[i], xyz[j], D, i, j,
projection=projection)
if labels:
for i in range(ndim):
axes[-1, i].set_xlabel(labels[i])
if i > 0:
axes[i, 0].set_ylabel(labels[i])
axes[i, i].set_ylabel(labels[-1])
return fig, axes
def projection_2D(ax, x, y, D, xidx, yidx, projection):
flat_idxs = range(D.ndim)
flat_idxs.remove(xidx)
flat_idxs.remove(yidx)
D2D = projection(D, axis=tuple(flat_idxs))
X, Y = np.meshgrid(x, y, indexing='ij')
pax = ax.pcolormesh(Y, X, D2D.T, vmin=D.min(), vmax=D.max())
return ax, pax
def projection_1D(ax, x, D, xidx, projection):
flat_idxs = range(D.ndim)
flat_idxs.remove(xidx)
D1D = projection(D, axis=tuple(flat_idxs))
ax.plot(x, D1D)
ax.yaxis.tick_right()
ax.yaxis.set_label_position("right")
return ax
setup.py 0 → 100644
#!/usr/bin/env python
from distutils.core import setup
setup(name='projection_matrix',
version='0.1',
author='Gregory Ashton',
author_email='gregory.ashton@ligo.org',
packages=['projection_matrix'],
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment