Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
G
gridcorner
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Gregory Ashton
gridcorner
Commits
bb07e608
Commit
bb07e608
authored
7 years ago
by
Gregory Ashton
Browse files
Options
Downloads
Patches
Plain Diff
Adds new projection methods and improve docs
parent
78457b59
No related branches found
No related tags found
No related merge requests found
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
projection_matrix/__init__.py
+1
-1
1 addition, 1 deletion
projection_matrix/__init__.py
projection_matrix/projection_matrix.py
+57
-12
57 additions, 12 deletions
projection_matrix/projection_matrix.py
with
58 additions
and
13 deletions
projection_matrix/__init__.py
+
1
−
1
View file @
bb07e608
from
.projection_matrix
import
projection_matrix
,
slice_max
from
.projection_matrix
import
projection_matrix
This diff is collapsed.
Click to expand it.
projection_matrix/projection_matrix.py
+
57
−
12
View file @
bb07e608
import
numpy
as
np
import
matplotlib.pyplot
as
plt
from
matplotlib.ticker
import
MaxNLocator
from
scipy.misc
import
logsumexp
def
slice_max
(
D
,
axis
):
def
log_mean
(
loga
,
axis
):
"""
Calculate the log(<a>) mean
Given `N` logged value `log`, calculate the log_mean
`log(<loga>)=log(sum(np.exp(loga))) - log(N)`. Useful for marginalizing
over logged likelihoods for example.
Parameters
----------
loga: array_like
Input_array.
axies: None or int or type of ints, optional
Axis or axes over which the sum is taken. By default axis is None, and
all elements are summed.
Returns
-------
log_mean: ndarry
The logged average value (shape loga.shape)
"""
loga
=
np
.
array
(
loga
)
N
=
np
.
prod
([
loga
.
shape
[
i
]
for
i
in
axis
])
return
logsumexp
(
loga
,
axis
)
-
np
.
log
(
N
)
def
max_slice
(
D
,
axis
):
"""
Return the slice along the given axis
"""
idxs
=
[
range
(
D
.
shape
[
j
])
for
j
in
range
(
D
.
ndim
)]
max_idx
=
list
(
np
.
unravel_index
(
D
.
argmax
(),
D
.
shape
))
...
...
@@ -13,8 +38,17 @@ def slice_max(D, axis):
return
res
def
projection_matrix
(
D
,
xyz
,
labels
=
None
,
projection
=
slice_max
,
max_n_ticks
=
4
,
factor
=
3
):
def
idx_array_slice
(
D
,
axis
,
slice_idx
):
"""
Return the slice along the given axis
"""
idxs
=
[
range
(
D
.
shape
[
j
])
for
j
in
range
(
D
.
ndim
)]
for
k
in
np
.
atleast_1d
(
axis
):
idxs
[
k
]
=
[
slice_idx
[
k
]]
res
=
np
.
squeeze
(
D
[
np
.
ix_
(
*
tuple
(
idxs
))])
return
res
def
projection_matrix
(
D
,
xyz
,
labels
=
None
,
projection
=
'
max_slice
'
,
max_n_ticks
=
4
,
factor
=
3
,
**
kwargs
):
"""
Generate a projection matrix plot
Parameters
...
...
@@ -28,9 +62,11 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
labels: list
N+1 length list of labels; the first N correspond to the coordinates
labels, the final label is for the dependent variable.
projection: func
Function to use for projection, must take an `axis` argument. Default
is `projection_matrix.slice_max()`, to project out a slice along the
projection: str or func
If a string, one of `{
"
log_mean
"
,
"
max_slice
"
} to use inbuilt functions
to calculate either the logged mean or maximum slice projection. Else
a function to use for projection, must take an `axis` argument. Default
is `projection_matrix.max_slice()`, to project out a slice along the
maximum.
max_n_ticks: int
Number of ticks for x and y axis of the `pcolormesh` plots
...
...
@@ -50,6 +86,14 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
plotdim
=
factor
*
ndim
+
factor
*
(
ndim
-
1.
)
*
whspace
dim
=
lbdim
+
plotdim
+
trdim
if
type
(
projection
)
==
str
:
if
projection
in
[
'
log_mean
'
]:
projection
=
log_mean
elif
projection
in
[
'
max_slice
'
]:
projection
=
max_slice
else
:
raise
ValueError
(
"
Projection {} not understood
"
.
format
(
projection
))
fig
,
axes
=
plt
.
subplots
(
ndim
,
ndim
,
figsize
=
(
dim
,
dim
))
# Format the figure.
...
...
@@ -58,7 +102,8 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
fig
.
subplots_adjust
(
left
=
lb
,
bottom
=
lb
,
right
=
tr
,
top
=
tr
,
wspace
=
whspace
,
hspace
=
whspace
)
for
i
in
range
(
ndim
):
projection_1D
(
axes
[
i
,
i
],
xyz
[
i
],
D
,
i
,
projection
=
projection
)
projection_1D
(
axes
[
i
,
i
],
xyz
[
i
],
D
,
i
,
projection
=
projection
,
**
kwargs
)
for
j
in
range
(
ndim
):
ax
=
axes
[
i
,
j
]
...
...
@@ -82,7 +127,7 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
ax
.
yaxis
.
set_major_locator
(
MaxNLocator
(
max_n_ticks
,
prune
=
"
upper
"
))
ax
,
pax
=
projection_2D
(
ax
,
xyz
[
i
],
xyz
[
j
],
D
,
i
,
j
,
projection
=
projection
)
projection
=
projection
,
**
kwargs
)
if
labels
:
for
i
in
range
(
ndim
):
...
...
@@ -93,20 +138,20 @@ def projection_matrix(D, xyz, labels=None, projection=slice_max, max_n_ticks=4,
return
fig
,
axes
def
projection_2D
(
ax
,
x
,
y
,
D
,
xidx
,
yidx
,
projection
):
def
projection_2D
(
ax
,
x
,
y
,
D
,
xidx
,
yidx
,
projection
,
**
kwargs
):
flat_idxs
=
range
(
D
.
ndim
)
flat_idxs
.
remove
(
xidx
)
flat_idxs
.
remove
(
yidx
)
D2D
=
projection
(
D
,
axis
=
tuple
(
flat_idxs
))
D2D
=
projection
(
D
,
axis
=
tuple
(
flat_idxs
)
,
**
kwargs
)
X
,
Y
=
np
.
meshgrid
(
x
,
y
,
indexing
=
'
ij
'
)
pax
=
ax
.
pcolormesh
(
Y
,
X
,
D2D
.
T
,
vmin
=
D
.
min
(),
vmax
=
D
.
max
())
return
ax
,
pax
def
projection_1D
(
ax
,
x
,
D
,
xidx
,
projection
):
def
projection_1D
(
ax
,
x
,
D
,
xidx
,
projection
,
**
kwargs
):
flat_idxs
=
range
(
D
.
ndim
)
flat_idxs
.
remove
(
xidx
)
D1D
=
projection
(
D
,
axis
=
tuple
(
flat_idxs
))
D1D
=
projection
(
D
,
axis
=
tuple
(
flat_idxs
)
,
**
kwargs
)
ax
.
plot
(
x
,
D1D
)
ax
.
yaxis
.
tick_right
()
ax
.
yaxis
.
set_label_position
(
"
right
"
)
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment