projection_matrix.py 5.14 KB
Newer Older
1
2
3
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
try:
    from scipy.special import logsumexp
except ImportError:  # older SciPy only provides scipy.misc.logsumexp
    from scipy.misc import logsumexp
5
6


7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def log_mean(loga, axis=None):
    """ Calculate the log(<a>) mean

    Given `N` logged values `loga`, calculate the log of their mean without
    leaving log space:
    `log(<a>) = log(sum(exp(loga))) - log(N)`. Useful for marginalizing
    over logged likelihoods, for example.

    Parameters
    ----------
    loga: array_like
        Input array of logged values.
    axis: None or int or tuple of ints, optional
        Axis or axes over which the mean is taken. By default axis is None,
        and the mean is taken over all elements.

    Returns
    -------
    log_mean: ndarray or scalar
        The logged average value; its shape is `loga.shape` with the
        averaged axes removed.
    """
    loga = np.asarray(loga)
    if axis is None:
        n = loga.size
    else:
        # Accept a single int (or a list) as well as a tuple of axes.
        axis = tuple(int(k) for k in np.atleast_1d(axis))
        n = np.prod([loga.shape[k] for k in axis])
    return logsumexp(loga, axis=axis) - np.log(n)


def max_slice(D, axis):
    """ Return the slice of `D` passing through its global maximum

    Each dimension listed in `axis` is pinned at the index of the maximum
    of `D` (the first occurrence in flattened order if the maximum is
    repeated); every other dimension is kept whole. Length-1 dimensions are
    squeezed out of the result.

    Parameters
    ----------
    D: ndarray
        N-dimensional data.
    axis: int or tuple of ints
        Dimension(s) to fix at the location of the maximum.

    Returns
    -------
    res: ndarray
        The squeezed slice (built with fancy indexing, so not a view of `D`).
    """
    peak = np.unravel_index(D.argmax(), D.shape)
    selectors = [np.arange(size) for size in D.shape]
    for dim in np.atleast_1d(axis):
        # a one-element selector keeps the dimension (length 1) until squeeze
        selectors[dim] = [peak[dim]]
    return np.squeeze(D[np.ix_(*selectors)])


41
42
43
44
45
46
47
48
49
50
51
def idx_array_slice(D, axis, slice_idx):
    """ Return the slice of `D` through a caller-chosen grid point

    Each dimension `k` listed in `axis` is pinned at index `slice_idx[k]`;
    every other dimension is kept whole. Length-1 dimensions are squeezed
    out of the result.

    Parameters
    ----------
    D: ndarray
        N-dimensional data.
    axis: int or tuple of ints
        Dimension(s) to fix.
    slice_idx: sequence of ints
        Indexed by dimension number; only the entries at the positions
        listed in `axis` are read.

    Returns
    -------
    res: ndarray
        The squeezed slice (built with fancy indexing, so not a view of `D`).
    """
    selectors = [np.arange(size) for size in D.shape]
    for dim in np.atleast_1d(axis):
        selectors[dim] = [slice_idx[dim]]
    return np.squeeze(D[np.ix_(*selectors)])


def _resolve_projection(projection):
    """ Map an inbuilt projection name to its function; pass others through """
    if isinstance(projection, str):
        inbuilt = {'log_mean': log_mean, 'max_slice': max_slice}
        if projection not in inbuilt:
            raise ValueError("Projection {} not understood".format(projection))
        return inbuilt[projection]
    return projection


def _share_axis(ax, other, which):
    """ Share the `which` ('x' or 'y') axis of `ax` with that of `other`

    Uses `Axes.sharex`/`Axes.sharey` (matplotlib >= 3.3) when available; the
    older `get_shared_*_axes().join` API was removed in matplotlib 3.8 and is
    only used as a fallback. Sharing an axes with itself is a no-op.
    """
    if ax is other:
        return
    share = getattr(ax, 'share' + which, None)
    if share is not None:
        share(other)
    else:
        getattr(ax, 'get_shared_{}_axes'.format(which))().join(other, ax)


def projection_matrix(D, xyz, labels=None, projection='max_slice',
                      max_n_ticks=4, factor=3, **kwargs):
    """ Generate a projection matrix plot

    Diagonal panel `(i, i)` shows the 1D projection of `D` onto coordinate
    `i`; each lower-triangle panel `(i, j)` (`j < i`) shows the 2D
    projection onto coordinate `j` (horizontal) and `i` (vertical). The
    upper triangle is left blank.

    Parameters
    ----------
    D: array_like
        N-dimensional data to plot, `D.shape` should be `(n1, n2,..., nn)`,
        where `ni` is the number of grid points along dimension `i`.
    xyz: list
        List of 1-dimensional arrays of coordinates. `xyz[i]` should have
        length `ni` (see help for `D`).
    labels: list, optional
        N+1 length list of labels; the first N correspond to the coordinate
        labels, the final label is for the dependent variable.
    projection: str or func
        If a string, one of `{"log_mean", "max_slice"}` to use inbuilt
        functions to calculate either the logged mean or maximum slice
        projection. Else a function to use for projection, must take an
        `axis` argument. Default is `max_slice`, to project out a slice
        through the maximum.
    max_n_ticks: int
        Number of ticks for x and y axis of the `pcolormesh` plots
    factor: float
        Controls the size of one window
    **kwargs:
        Passed on to the projection function.

    Returns
    -------
    fig, axes:
        The figure and NxN array of axes

    Raises
    ------
    ValueError
        If `projection` is a string other than the inbuilt names.
    """
    D = np.asanyarray(D)
    ndim = D.ndim
    # Resolve before creating a figure so a bad name fails fast.
    projection = _resolve_projection(projection)

    lbdim = 0.5 * factor   # size of left/bottom margin
    trdim = 0.2 * factor   # size of top/right margin
    whspace = 0.05         # w/hspace size
    plotdim = factor * ndim + factor * (ndim - 1.) * whspace
    dim = lbdim + plotdim + trdim

    # squeeze=False keeps `axes` 2D even when ndim == 1
    fig, axes = plt.subplots(ndim, ndim, figsize=(dim, dim), squeeze=False)

    # Format the figure.
    lb = lbdim / dim
    tr = (lbdim + plotdim) / dim
    fig.subplots_adjust(left=lb, bottom=lb, right=tr, top=tr,
                        wspace=whspace, hspace=whspace)
    for i in range(ndim):
        projection_1D(
            axes[i, i], xyz[i], D, i, projection=projection, **kwargs)

        for j in range(ndim):
            ax = axes[i, j]

            if j > i:
                # Upper triangle: hide the panel entirely.
                ax.set_frame_on(False)
                ax.set_xticks([])
                ax.set_yticks([])
                continue

            # Every panel in column j shares the x-axis of the bottom panel.
            _share_axis(ax, axes[ndim - 1, j], 'x')
            if i < ndim - 1:
                # Hide labels per-axes; shared axes share one formatter, so
                # set_xticklabels([]) would blank the whole column.
                ax.tick_params(labelbottom=False)
            if j < i:
                _share_axis(ax, axes[i, i - 1], 'y')
                if j > 0:
                    ax.tick_params(labelleft=False)
            if j == i:
                continue

            ax.xaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="upper"))
            ax.yaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="upper"))

            projection_2D(ax, xyz[i], xyz[j], D, i, j,
                          projection=projection, **kwargs)

    if labels:
        for i in range(ndim):
            axes[-1, i].set_xlabel(labels[i])
            if i > 0:
                axes[i, 0].set_ylabel(labels[i])
            axes[i, i].set_ylabel(labels[-1])
    return fig, axes


141
def projection_2D(ax, x, y, D, xidx, yidx, projection, **kwargs):
    """ Draw the 2D projection of `D` onto dimensions `xidx` and `yidx`

    Parameters
    ----------
    ax: matplotlib axes
        Axes to draw the `pcolormesh` on.
    x, y: array_like
        Coordinates along dimensions `xidx` and `yidx` of `D`; `y` is drawn
        on the horizontal axis and `x` on the vertical axis.
    D: ndarray
        N-dimensional data.
    xidx, yidx: int
        Dimensions of `D` to keep; all others are projected out.
    projection: func
        Called as `projection(D, axis=..., **kwargs)` where `axis` is the
        tuple of all other dimensions, in ascending order.
    **kwargs:
        Passed on to `projection`.

    Returns
    -------
    ax, pax:
        The input axes and the `pcolormesh` artist. The colour scale spans
        `D.min()` to `D.max()`.
    """
    # A list comprehension works on Python 2 and 3 (range().remove is Py2-only)
    flat_idxs = [k for k in range(D.ndim) if k not in (xidx, yidx)]
    D2D = projection(D, axis=tuple(flat_idxs), **kwargs)
    X, Y = np.meshgrid(x, y, indexing='ij')
    # The remaining dims of D2D keep ascending order, so D2D is (y, x)-ordered
    # only when yidx < xidx; orient it to match X and Y, which are (x, y).
    C = D2D.T if yidx < xidx else D2D
    pax = ax.pcolormesh(Y, X, C, vmin=D.min(), vmax=D.max())
    return ax, pax


151
def projection_1D(ax, x, D, xidx, projection, **kwargs):
    """ Plot the 1D projection of `D` onto dimension `xidx`

    Parameters
    ----------
    ax: matplotlib axes
        Axes to plot on.
    x: array_like
        Coordinates along dimension `xidx` of `D`.
    D: ndarray
        N-dimensional data.
    xidx: int
        Dimension of `D` to keep; all others are projected out.
    projection: func
        Called as `projection(D, axis=..., **kwargs)` where `axis` is the
        tuple of all other dimensions, in ascending order.
    **kwargs:
        Passed on to `projection`.

    Returns
    -------
    ax:
        The input axes, with y ticks and y label moved to the right.
    """
    # A list comprehension works on Python 2 and 3 (range().remove is Py2-only)
    flat_idxs = [k for k in range(D.ndim) if k != xidx]
    D1D = projection(D, axis=tuple(flat_idxs), **kwargs)
    ax.plot(x, D1D)
    ax.yaxis.tick_right()
    ax.yaxis.set_label_position("right")
    return ax
159
160
161