Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Gregory Ashton
PyFstat
Commits
9c2a9c57
Commit
9c2a9c57
authored
Nov 29, 2017
by
Gregory Ashton
Browse files
Add transform_dict methods to plot_walkers
parent
ecf6c1c2
Changes
1
Show whitespace changes
Inline
Side-by-side
pyfstat/mcmc_based_searches.py
View file @
9c2a9c57
...
...
@@ -502,7 +502,6 @@ class MCMCSearch(core.BaseSearchClass):
sampler
=
self
.
_run_sampler
(
sampler
,
p0
,
nburn
=
n
,
window
=
window
)
if
create_plots
:
fig
,
axes
=
self
.
_plot_walkers
(
sampler
,
symbols
=
self
.
theta_symbols
,
**
kwargs
)
fig
.
tight_layout
()
fig
.
savefig
(
'{}/{}_init_{}_walkers.png'
.
format
(
...
...
@@ -522,8 +521,7 @@ class MCMCSearch(core.BaseSearchClass):
nburn
+
nprod
))
sampler
=
self
.
_run_sampler
(
sampler
,
p0
,
nburn
=
nburn
,
nprod
=
nprod
)
if
create_plots
:
fig
,
axes
=
self
.
_plot_walkers
(
sampler
,
symbols
=
self
.
theta_symbols
,
nprod
=
nprod
,
**
kwargs
)
fig
,
axes
=
self
.
_plot_walkers
(
sampler
,
nprod
=
nprod
,
**
kwargs
)
fig
.
tight_layout
()
fig
.
savefig
(
'{}/{}_walkers.png'
.
format
(
self
.
outdir
,
self
.
label
),
)
...
...
@@ -603,7 +601,7 @@ class MCMCSearch(core.BaseSearchClass):
return
samples
def
_get_labels
(
self
):
def
_get_labels
(
self
,
newline_units
=
False
):
""" Combine the units, symbols and rescaling to give labels """
labels
=
[]
...
...
@@ -620,7 +618,10 @@ class MCMCSearch(core.BaseSearchClass):
if
'unit'
in
self
.
transform_dictionary
[
key
]:
u
=
self
.
transform_dictionary
[
key
][
'unit'
]
if
label
is
None
:
if
newline_units
:
label
=
'{}
\n
[{}]'
.
format
(
s
,
u
)
else
:
label
=
'{} [{}]'
.
format
(
s
,
u
)
labels
.
append
(
label
)
return
labels
...
...
@@ -694,7 +695,7 @@ class MCMCSearch(core.BaseSearchClass):
fig
,
axes
=
fig_and_axes
samples_plt
=
copy
.
copy
(
self
.
samples
)
labels
=
self
.
_get_labels
()
labels
=
self
.
_get_labels
(
newline_units
=
True
)
samples_plt
=
self
.
_scale_samples
(
samples_plt
,
self
.
theta_keys
)
...
...
@@ -963,9 +964,11 @@ class MCMCSearch(core.BaseSearchClass):
def
_plot_walkers
(
self
,
sampler
,
symbols
=
None
,
alpha
=
0.8
,
color
=
"k"
,
temp
=
0
,
lw
=
0.1
,
nprod
=
0
,
add_det_stat_burnin
=
False
,
fig
=
None
,
axes
=
None
,
xoffset
=
0
,
plot_det_stat
=
False
,
context
=
'ggplot'
,
subtractions
=
None
,
labelpad
=
0.0
5
):
context
=
'ggplot'
,
labelpad
=
5
):
""" Plot all the chains from a sampler """
if
symbols
is
None
:
symbols
=
self
.
_get_labels
()
if
context
not
in
plt
.
style
.
available
:
raise
ValueError
((
'The requested context {} is not available; please select a'
...
...
@@ -977,7 +980,7 @@ class MCMCSearch(core.BaseSearchClass):
shape
=
sampler
.
chain
.
shape
if
len
(
shape
)
==
3
:
nwalkers
,
nsteps
,
ndim
=
shape
chain
=
sampler
.
chain
[:,
:,
:]
chain
=
sampler
.
chain
[:,
:,
:]
.
copy
()
if
len
(
shape
)
==
4
:
ntemps
,
nwalkers
,
nsteps
,
ndim
=
shape
if
temp
<
ntemps
:
...
...
@@ -985,13 +988,11 @@ class MCMCSearch(core.BaseSearchClass):
else
:
raise
ValueError
((
"Requested temperature {} outside of"
"available range"
).
format
(
temp
))
chain
=
sampler
.
chain
[
temp
,
:,
:,
:]
chain
=
sampler
.
chain
[
temp
,
:,
:,
:]
.
copy
()
if
subtractions
is
None
:
subtractions
=
[
0
for
i
in
range
(
ndim
)]
else
:
if
len
(
subtractions
)
!=
self
.
ndim
:
raise
ValueError
(
'subtractions must be of length ndim'
)
samples
=
chain
.
reshape
((
nwalkers
*
nsteps
,
ndim
))
samples
=
self
.
_scale_samples
(
samples
,
self
.
theta_keys
)
chain
=
chain
.
reshape
((
nwalkers
,
nsteps
,
ndim
))
if
plot_det_stat
:
extra_subplots
=
1
...
...
@@ -1017,23 +1018,24 @@ class MCMCSearch(core.BaseSearchClass):
cs
=
chain
[:,
:,
i
].
T
if
burnin_idx
>
0
:
axes
[
i
].
plot
(
xoffset
+
idxs
[:
last_idx
+
1
],
cs
[:
last_idx
+
1
]
-
subtractions
[
i
]
,
cs
[:
last_idx
+
1
],
color
=
"C3"
,
alpha
=
alpha
,
lw
=
lw
)
axes
[
i
].
axvline
(
xoffset
+
last_idx
,
color
=
'k'
,
ls
=
'--'
,
lw
=
0.5
)
axes
[
i
].
plot
(
xoffset
+
idxs
[
burnin_idx
:],
cs
[
burnin_idx
:]
-
subtractions
[
i
]
,
cs
[
burnin_idx
:],
color
=
"k"
,
alpha
=
alpha
,
lw
=
lw
)
axes
[
i
].
set_xlim
(
0
,
xoffset
+
idxs
[
-
1
])
if
symbols
:
if
subtractions
[
i
]
==
0
:
axes
[
i
].
set_ylabel
(
symbols
[
i
],
labelpad
=
labelpad
)
else
:
axes
[
i
].
set_ylabel
(
symbols
[
i
]
+
'$-$'
+
symbols
[
i
]
+
'$^\mathrm{s}$'
,
labelpad
=
labelpad
)
#if subtractions[i] == 0:
# axes[i].set_ylabel(symbols[i], labelpad=labelpad)
#else:
# axes[i].set_ylabel(
# symbols[i]+'$-$'+symbols[i]+'$^\mathrm{s}$',
# labelpad=labelpad)
# if hasattr(self, 'convergence_diagnostic'):
# ax = axes[i].twinx()
...
...
@@ -2120,7 +2122,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
if
create_plots
:
fig
,
axes
=
self
.
_plot_walkers
(
sampler
,
symbols
=
self
.
theta_symbols
,
fig
=
fig
,
axes
=
axes
,
sampler
,
fig
=
fig
,
axes
=
axes
,
nprod
=
nprod
,
xoffset
=
nsteps_total
,
**
kwargs
)
for
ax
in
axes
[:
self
.
ndim
]:
ax
.
axvline
(
nsteps_total
,
color
=
'k'
,
ls
=
'--'
,
lw
=
0.25
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment