Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
David Keitel
gridcorner
Commits
bb07e608
Commit
bb07e608
authored
Oct 30, 2017
by
Gregory Ashton
Browse files
Adds new projection methods and improve docs
parent
78457b59
Changes
2
Hide whitespace changes
Inline
Side-by-side
projection_matrix/__init__.py
View file @
bb07e608
from
.projection_matrix
import
projection_matrix
,
slice_max
from
.projection_matrix
import
projection_matrix
projection_matrix/projection_matrix.py
View file @
bb07e608
import
numpy
as
np
import
matplotlib.pyplot
as
plt
from
matplotlib.ticker
import
MaxNLocator
from
scipy.misc
import
logsumexp
def
slice_max
(
D
,
axis
):
def
log_mean
(
loga
,
axis
):
""" Calculate the log(<a>) mean
Given `N` logged value `log`, calculate the log_mean
`log(<loga>)=log(sum(np.exp(loga))) - log(N)`. Useful for marginalizing
over logged likelihoods for example.
Parameters
----------
loga: array_like
Input_array.
axies: None or int or type of ints, optional
Axis or axes over which the sum is taken. By default axis is None, and
all elements are summed.
Returns
-------
log_mean: ndarry
The logged average value (shape loga.shape)
"""
loga
=
np
.
array
(
loga
)
N
=
np
.
prod
([
loga
.
shape
[
i
]
for
i
in
axis
])
return
logsumexp
(
loga
,
axis
)
-
np
.
log
(
N
)
def
max_slice
(
D
,
axis
):
""" Return the slice along the given axis """
idxs
=
[
range
(
D
.
shape
[
j
])
for
j
in
range
(
D
.
ndim
)]
max_idx
=
list
(
np
.
unravel_index
(
D
.
argmax
(),
D
.
shape
))
...
...
@@ -13,8 +38,17 @@ def slice_max(D, axis):
return
res
def
projection_matrix
(
D
,
xyz
,
labels
=
None
,
projection
=
slice_max
,
max_n_ticks
=
4
,
factor
=
3
):
def
idx_array_slice
(
D
,
axis
,
slice_idx
):
""" Return the slice along the given axis """
idxs
=
[
range
(
D
.
shape
[
j
])
for
j
in
range
(
D
.
ndim
)]
for
k
in
np
.
atleast_1d
(
axis
):
idxs
[
k
]
=
[
slice_idx
[
k
]]
res
=
np
.
squeeze
(
D
[
np
.
ix_
(
*
tuple
(
idxs
))])
return
res
def
projection_matrix
(
D
,
xyz
,
labels
=
None
,
projection
=
'max_slice'
,
max_n_ticks
=
4
,
factor
=
3
,
**
kwargs
):
""" Generate a projection matrix plot
Parameters
...
...
@@ -28,9 +62,11 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
labels: list
N+1 length list of labels; the first N correspond to the coordinates
labels, the final label is for the dependent variable.
projection: func
Function to use for projection, must take an `axis` argument. Default
is `projection_matrix.slice_max()`, to project out a slice along the
projection: str or func
If a string, one of `{"log_mean", "max_slice"} to use inbuilt functions
to calculate either the logged mean or maximum slice projection. Else
a function to use for projection, must take an `axis` argument. Default
is `projection_matrix.max_slice()`, to project out a slice along the
maximum.
max_n_ticks: int
Number of ticks for x and y axis of the `pcolormesh` plots
...
...
@@ -50,6 +86,14 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
plotdim
=
factor
*
ndim
+
factor
*
(
ndim
-
1.
)
*
whspace
dim
=
lbdim
+
plotdim
+
trdim
if
type
(
projection
)
==
str
:
if
projection
in
[
'log_mean'
]:
projection
=
log_mean
elif
projection
in
[
'max_slice'
]:
projection
=
max_slice
else
:
raise
ValueError
(
"Projection {} not understood"
.
format
(
projection
))
fig
,
axes
=
plt
.
subplots
(
ndim
,
ndim
,
figsize
=
(
dim
,
dim
))
# Format the figure.
...
...
@@ -58,7 +102,8 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
fig
.
subplots_adjust
(
left
=
lb
,
bottom
=
lb
,
right
=
tr
,
top
=
tr
,
wspace
=
whspace
,
hspace
=
whspace
)
for
i
in
range
(
ndim
):
projection_1D
(
axes
[
i
,
i
],
xyz
[
i
],
D
,
i
,
projection
=
projection
)
projection_1D
(
axes
[
i
,
i
],
xyz
[
i
],
D
,
i
,
projection
=
projection
,
**
kwargs
)
for
j
in
range
(
ndim
):
ax
=
axes
[
i
,
j
]
...
...
@@ -82,7 +127,7 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
ax
.
yaxis
.
set_major_locator
(
MaxNLocator
(
max_n_ticks
,
prune
=
"upper"
))
ax
,
pax
=
projection_2D
(
ax
,
xyz
[
i
],
xyz
[
j
],
D
,
i
,
j
,
projection
=
projection
)
projection
=
projection
,
**
kwargs
)
if
labels
:
for
i
in
range
(
ndim
):
...
...
@@ -93,20 +138,20 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
return
fig
,
axes
def
projection_2D
(
ax
,
x
,
y
,
D
,
xidx
,
yidx
,
projection
):
def
projection_2D
(
ax
,
x
,
y
,
D
,
xidx
,
yidx
,
projection
,
**
kwargs
):
flat_idxs
=
range
(
D
.
ndim
)
flat_idxs
.
remove
(
xidx
)
flat_idxs
.
remove
(
yidx
)
D2D
=
projection
(
D
,
axis
=
tuple
(
flat_idxs
))
D2D
=
projection
(
D
,
axis
=
tuple
(
flat_idxs
)
,
**
kwargs
)
X
,
Y
=
np
.
meshgrid
(
x
,
y
,
indexing
=
'ij'
)
pax
=
ax
.
pcolormesh
(
Y
,
X
,
D2D
.
T
,
vmin
=
D
.
min
(),
vmax
=
D
.
max
())
return
ax
,
pax
def
projection_1D
(
ax
,
x
,
D
,
xidx
,
projection
):
def
projection_1D
(
ax
,
x
,
D
,
xidx
,
projection
,
**
kwargs
):
flat_idxs
=
range
(
D
.
ndim
)
flat_idxs
.
remove
(
xidx
)
D1D
=
projection
(
D
,
axis
=
tuple
(
flat_idxs
))
D1D
=
projection
(
D
,
axis
=
tuple
(
flat_idxs
)
,
**
kwargs
)
ax
.
plot
(
x
,
D1D
)
ax
.
yaxis
.
tick_right
()
ax
.
yaxis
.
set_label_position
(
"right"
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment