# projection_matrix.py
 Gregory Ashton committed Oct 15, 2017 1 2 3 ``````import numpy as np import matplotlib.pyplot as plt from matplotlib.ticker import MaxNLocator `````` Gregory Ashton committed Oct 30, 2017 4 ``````from scipy.misc import logsumexp `````` Gregory Ashton committed Oct 15, 2017 5 6 `````` `````` Gregory Ashton committed Oct 30, 2017 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 ``````def log_mean(loga, axis): """ Calculate the log() mean Given `N` logged value `log`, calculate the log_mean `log()=log(sum(np.exp(loga))) - log(N)`. Useful for marginalizing over logged likelihoods for example. Parameters ---------- loga: array_like Input_array. axies: None or int or type of ints, optional Axis or axes over which the sum is taken. By default axis is None, and all elements are summed. Returns ------- log_mean: ndarry The logged average value (shape loga.shape) """ loga = np.array(loga) N = np.prod([loga.shape[i] for i in axis]) return logsumexp(loga, axis) - np.log(N) def max_slice(D, axis): `````` Gregory Ashton committed Oct 15, 2017 32 33 34 35 36 37 38 39 40 `````` """ Return the slice along the given axis """ idxs = [range(D.shape[j]) for j in range(D.ndim)] max_idx = list(np.unravel_index(D.argmax(), D.shape)) for k in np.atleast_1d(axis): idxs[k] = [max_idx[k]] res = np.squeeze(D[np.ix_(*tuple(idxs))]) return res `````` Gregory Ashton committed Oct 30, 2017 41 42 43 44 45 46 47 48 49 50 51 ``````def idx_array_slice(D, axis, slice_idx): """ Return the slice along the given axis """ idxs = [range(D.shape[j]) for j in range(D.ndim)] for k in np.atleast_1d(axis): idxs[k] = [slice_idx[k]] res = np.squeeze(D[np.ix_(*tuple(idxs))]) return res def projection_matrix(D, xyz, labels=None, projection='max_slice', max_n_ticks=4, factor=3, **kwargs): `````` Gregory Ashton committed Oct 15, 2017 52 53 54 55 56 57 58 59 60 61 62 63 64 `````` """ Generate a projection matrix plot Parameters ---------- D: array_like N-dimensional data to plot, `D.shape` 
should be `(n1, n2,..., nn)`, where `ni`, is the number of grid points along dimension `i`. xyz: list List of 1-dimensional arrays of coordinates. `xyz[i]` should have length `ni` (see help for `D`). labels: list N+1 length list of labels; the first N correspond to the coordinates labels, the final label is for the dependent variable. `````` Gregory Ashton committed Oct 30, 2017 65 66 67 68 69 `````` projection: str or func If a string, one of `{"log_mean", "max_slice"} to use inbuilt functions to calculate either the logged mean or maximum slice projection. Else a function to use for projection, must take an `axis` argument. Default is `projection_matrix.max_slice()`, to project out a slice along the `````` Gregory Ashton committed Oct 15, 2017 70 `````` maximum. `````` Gregory Ashton committed Oct 15, 2017 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 `````` max_n_ticks: int Number of ticks for x and y axis of the `pcolormesh` plots factor: float Controls the size of one window Returns ------- fig, axes: The figure and NxN set of axes """ ndim = D.ndim lbdim = 0.5 * factor # size of left/bottom margin trdim = 0.2 * factor # size of top/right margin whspace = 0.05 # w/hspace size plotdim = factor * ndim + factor * (ndim - 1.) * whspace dim = lbdim + plotdim + trdim `````` Gregory Ashton committed Oct 30, 2017 89 90 91 92 93 94 95 96 `````` if type(projection) == str: if projection in ['log_mean']: projection = log_mean elif projection in ['max_slice']: projection = max_slice else: raise ValueError("Projection {} not understood".format(projection)) `````` Gregory Ashton committed Oct 15, 2017 97 98 99 100 101 102 103 104 `````` fig, axes = plt.subplots(ndim, ndim, figsize=(dim, dim)) # Format the figure. 
lb = lbdim / dim tr = (lbdim + plotdim) / dim fig.subplots_adjust(left=lb, bottom=lb, right=tr, top=tr, wspace=whspace, hspace=whspace) for i in range(ndim): `````` Gregory Ashton committed Oct 30, 2017 105 106 `````` projection_1D( axes[i, i], xyz[i], D, i, projection=projection, **kwargs) `````` Gregory Ashton committed Oct 15, 2017 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 `````` for j in range(ndim): ax = axes[i, j] if j > i: ax.set_frame_on(False) ax.set_xticks([]) ax.set_yticks([]) continue ax.get_shared_x_axes().join(axes[ndim-1, j], ax) if i < ndim - 1: ax.set_xticklabels([]) if j < i: ax.get_shared_y_axes().join(axes[i, i-1], ax) if j > 0: ax.set_yticklabels([]) if j == i: continue ax.xaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="upper")) ax.yaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="upper")) ax, pax = projection_2D(ax, xyz[i], xyz[j], D, i, j, `````` Gregory Ashton committed Oct 30, 2017 130 `````` projection=projection, **kwargs) `````` Gregory Ashton committed Oct 15, 2017 131 132 133 134 135 136 137 138 139 140 `````` if labels: for i in range(ndim): axes[-1, i].set_xlabel(labels[i]) if i > 0: axes[i, 0].set_ylabel(labels[i]) axes[i, i].set_ylabel(labels[-1]) return fig, axes `````` Gregory Ashton committed Oct 30, 2017 141 ``````def projection_2D(ax, x, y, D, xidx, yidx, projection, **kwargs): `````` Gregory Ashton committed Oct 15, 2017 142 143 144 `````` flat_idxs = range(D.ndim) flat_idxs.remove(xidx) flat_idxs.remove(yidx) `````` Gregory Ashton committed Oct 30, 2017 145 `````` D2D = projection(D, axis=tuple(flat_idxs), **kwargs) `````` Gregory Ashton committed Oct 15, 2017 146 147 148 149 150 `````` X, Y = np.meshgrid(x, y, indexing='ij') pax = ax.pcolormesh(Y, X, D2D.T, vmin=D.min(), vmax=D.max()) return ax, pax `````` Gregory Ashton committed Oct 30, 2017 151 ``````def projection_1D(ax, x, D, xidx, projection, **kwargs): `````` Gregory Ashton committed Oct 15, 2017 152 153 
`````` flat_idxs = range(D.ndim) flat_idxs.remove(xidx) `````` Gregory Ashton committed Oct 30, 2017 154 `````` D1D = projection(D, axis=tuple(flat_idxs), **kwargs) `````` Gregory Ashton committed Oct 15, 2017 155 156 157 158 `````` ax.plot(x, D1D) ax.yaxis.tick_right() ax.yaxis.set_label_position("right") return ax `````` Gregory Ashton committed Oct 15, 2017 159 160 161 `````` ``````