Gregory Ashton / PyFstat · Commits

Commit afbb9815, authored Aug 08, 2017 by Gregory Ashton

Reorganisation and cleanup of convergence testing

- Adds autocorrelation attempt (using PR 223 to emcee)

parent 50e20741

Changes: 1 file
pyfstat/mcmc_based_searches.py
```diff
@@ -239,61 +239,63 @@ class MCMCSearch(core.BaseSearchClass):
             pass
         return sampler

-    def setup_convergence_testing(
-            self, convergence_period=10, convergence_length=10,
-            convergence_burnin_fraction=0.25, convergence_threshold_number=10,
-            convergence_threshold=1.2, convergence_prod_threshold=2,
-            convergence_plot_upper_lim=2, convergence_early_stopping=True):
+    def setup_burnin_convergence_testing(
+            self, n=10, test_type='autocorr', windowed=False, **kwargs):
         """
         If called, convergence testing is used during the MCMC simulation
-        This uses the Gelman-Rubin statistic based on the ratio of between and
-        within walkers variance. The original statistic was developed for
-        multiple (independent) MCMC simulations; in this context we simply use
-        the walkers

         Parameters
         ----------
-        convergence_period: int
-            period (in number of steps) at which to test convergence
-        convergence_length: int
-            number of steps to use in testing convergence - this should be
-            large enough to measure the variance, but if it is too long
-            this will result in incorrect early convergence tests
-        convergence_burnin_fraction: float [0, 1]
-            the fraction of the burn-in period after which to start testing
-        convergence_threshold_number: int
-            the number of consecutive times the test passes, after which to
-            break the burn-in and go to production
-        convergence_threshold: float
-            the threshold to use in diagnosing convergence. Gelman & Rubin
-            recommend a value of 1.2, or 1.1 for strict convergence
-        convergence_prod_threshold: float
-            the threshold to test the production values with
-        convergence_plot_upper_lim: float
-            the upper limit to use in the diagnostic plot
-        convergence_early_stopping: bool
-            if true, stop the burn-in early if convergence is reached
+        n: int
+            Number of steps after which to test convergence
+        test_type: str ['autocorr', 'GR']
+            If 'autocorr' use the exponential autocorrelation time (kwargs
+            passed to `get_autocorr_convergence`). If 'GR' use the Gelman-Rubin
+            statistic (kwargs passed to `get_GR_convergence`)
+        windowed: bool
+            If True, only calculate the convergence test in a window of length
+            `n`
         """
-        if convergence_length > convergence_period:
-            raise ValueError('convergence_length must be < convergence_period')
         logging.info('Setting up convergence testing')
-        self.convergence_length = convergence_length
-        self.convergence_period = convergence_period
-        self.convergence_burnin_fraction = convergence_burnin_fraction
-        self.convergence_prod_threshold = convergence_prod_threshold
+        self.convergence_n = n
+        self.convergence_windowed = windowed
+        self.convergence_test_type = test_type
+        self.convergence_kwargs = kwargs
         self.convergence_diagnostic = []
         self.convergence_diagnosticx = []
-        self.convergence_threshold_number = convergence_threshold_number
-        self.convergence_threshold = convergence_threshold
-        self.convergence_number = 0
-        self.convergence_plot_upper_lim = convergence_plot_upper_lim
-        self.convergence_early_stopping = convergence_early_stopping
-
-    def _get_convergence_statistic(self, i, sampler):
-        s = sampler.chain[0, :, i - self.convergence_length + 1:i + 1, :]
-        N = float(self.convergence_length)
+        if test_type in ['autocorr']:
+            self._get_convergence_test = self.test_autocorr_convergence
+        elif test_type in ['GR']:
+            self._get_convergence_test = self.test_GR_convergence
+        else:
+            raise ValueError('test_type {} not understood'.format(test_type))
+
+    def test_autocorr_convergence(self, i, sampler, test=True, n_cut=5):
+        try:
+            acors = np.zeros((self.ntemps, self.ndim))
+            for temp in range(self.ntemps):
+                if self.convergence_windowed:
+                    j = i - self.convergence_n
+                else:
+                    j = 0
+                x = np.mean(sampler.chain[temp, :, j:i, :], axis=0)
+                acors[temp, :] = emcee.autocorr.exponential_time(x)
+            c = np.max(acors, axis=0)
+        except emcee.autocorr.AutocorrError:
+            c = np.zeros(self.ndim) + np.nan
+        self.convergence_diagnosticx.append(i - self.convergence_n / 2.)
+        self.convergence_diagnostic.append(list(c))
+        if test:
+            return i > n_cut * np.max(c)
+
+    def test_GR_convergence(self, i, sampler, test=True, R=1.1):
+        if self.convergence_windowed:
+            s = sampler.chain[0, :, i - self.convergence_n + 1:i + 1, :]
+        else:
+            s = sampler.chain[0, :, :i + 1, :]
+        N = float(self.convergence_n)
         M = float(self.nwalkers)
         W = np.mean(np.var(s, axis=1), axis=0)
         per_walker_mean = np.mean(s, axis=1)
```
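Note: the added `test_autocorr_convergence` relies on `emcee.autocorr.exponential_time` from the then-unmerged emcee PR 223, so it is not available in released emcee versions. The following is a minimal, numpy-only sketch of the same criterion, with a crude exponential fit to the autocorrelation function standing in for that helper; the chain layout `(nwalkers, nsteps, ndim)` and the `i > n_cut * tau` test follow the diff, everything else is illustrative.

```python
import numpy as np

def exponential_autocorr_time(x):
    """Crude tau_exp estimate for a 1D series: fit a line to log ACF(t)
    over the initial, positive part of the autocorrelation function,
    assuming ACF(t) ~ exp(-t / tau)."""
    x = np.asarray(x, dtype=float)
    x = x - x.mean()
    acf = np.correlate(x, x, mode='full')[x.size - 1:]
    if acf[0] <= 0:  # constant series, no variance to measure
        return np.nan
    acf = acf / acf[0]
    cut = np.argmax(acf <= 0)  # first non-positive lag (0 if none)
    if cut == 0:
        cut = acf.size
    if cut < 3:
        return np.nan
    lags = np.arange(1, cut)
    slope = np.polyfit(lags, np.log(acf[1:cut]), 1)[0]
    return -1.0 / slope if slope < 0 else np.inf

def burnin_converged(chain, i, n_cut=5):
    """chain: (nwalkers, nsteps, ndim) for a single temperature, as in the
    diff; declare convergence once i exceeds n_cut autocorrelation times,
    mirroring `i > n_cut * np.max(c)` above."""
    x = np.mean(chain[:, :i, :], axis=0)  # walker-averaged chain
    taus = [exponential_autocorr_time(x[:, k]) for k in range(x.shape[1])]
    return i > n_cut * np.nanmax(taus)
```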
```diff
@@ -302,58 +304,45 @@ class MCMCSearch(core.BaseSearchClass):
         Vhat = (N - 1) / N * W + (M + 1) / (M * N) * B
         c = np.sqrt(Vhat / W)
         self.convergence_diagnostic.append(c)
-        self.convergence_diagnosticx.append(i - self.convergence_length / 2)
-        return c
+        self.convergence_diagnosticx.append(i - self.convergence_n / 2.)
+        if test and np.max(c) < R:
+            return True
+        else:
+            return False

-    def _burnin_convergence_test(self, i, sampler, nburn):
-        if i < self.convergence_burnin_fraction * nburn:
-            return False
-        if np.mod(i + 1, self.convergence_period) != 0:
-            return False
-        c = self._get_convergence_statistic(i, sampler)
-        if np.all(c < self.convergence_threshold):
-            self.convergence_number += 1
-        else:
-            self.convergence_number = 0
-        if self.convergence_early_stopping:
-            return self.convergence_number > self.convergence_threshold_number
-
-    def _prod_convergence_test(self, i, sampler, nburn):
-        testA = i > nburn + self.convergence_length
-        testB = np.mod(i + 1, self.convergence_period) == 0
-        if testA and testB:
-            self._get_convergence_statistic(i, sampler)
-
-    def _check_production_convergence(self, k):
-        bools = np.any(
-            np.array(self.convergence_diagnostic)[k:, :]
-            > self.convergence_prod_threshold, axis=1)
-        if np.any(bools):
-            logging.warning(
-                '{} convergence tests in the production run of {} failed'
-                .format(np.sum(bools), len(bools)))
+    def _test_convergence(self, i, sampler, **kwargs):
+        if np.mod(i + 1, self.convergence_n) == 0:
+            return self._get_convergence_test(i, sampler, **kwargs)
+        else:
+            return False
+
+    def _run_sampler_with_conv_test(self, sampler, p0, nprod=0, nburn=0):
+        logging.info('Running {} burn-in steps with convergence testing'
+                     .format(nburn))
+        iterator = tqdm(sampler.sample(p0, iterations=nburn), total=nburn)
+        for i, output in enumerate(iterator):
+            if self._test_convergence(i, sampler, test=True,
+                                      **self.convergence_kwargs):
+                logging.info('Converged at {} before max number {} of steps'
+                             ' reached'.format(i, nburn))
+                self.convergence_idx = i
+                break
+        iterator.close()
+        logging.info('Running {} production steps'.format(nprod))
+        j = nburn
+        iterator = tqdm(sampler.sample(output[0], iterations=nprod),
+                        total=nprod)
+        for result in iterator:
+            self._test_convergence(j, sampler, test=False,
+                                   **self.convergence_kwargs)
+            j += 1
+        return sampler

     def _run_sampler(self, sampler, p0, nprod=0, nburn=0):
-        if hasattr(self, 'convergence_period'):
-            logging.info('Running {} burn-in steps with convergence testing'
-                         .format(nburn))
-            iterator = tqdm(sampler.sample(p0, iterations=nburn), total=nburn)
-            for i, output in enumerate(iterator):
-                if self._burnin_convergence_test(i, sampler, nburn):
-                    logging.info('Converged at {} before max number {} of'
-                                 ' steps reached'.format(i, nburn))
-                    self.convergence_idx = i
-                    break
-            iterator.close()
-            logging.info('Running {} production steps'.format(nprod))
-            j = nburn
-            k = len(self.convergence_diagnostic)
-            for result in tqdm(sampler.sample(output[0], iterations=nprod),
-                               total=nprod):
-                self._prod_convergence_test(j, sampler, nburn)
-                j += 1
-            self._check_production_convergence(k)
+        if hasattr(self, 'convergence_n'):
+            self._run_sampler_with_conv_test(sampler, p0, nprod, nburn)
         else:
             for result in tqdm(sampler.sample(p0, iterations=nburn + nprod),
                                total=nburn + nprod):
```
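For reference, the quantity compared against `R` in `test_GR_convergence` is the Gelman-Rubin potential scale reduction factor (PSRF) computed across walkers. The between-walker variance `B` is defined in context lines collapsed out of this hunk, so the sketch below substitutes the textbook definition (an assumption); the `W`, `Vhat`, and `c = sqrt(Vhat / W)` steps mirror the diff.

```python
import numpy as np

def gelman_rubin_psrf(s):
    """s: (nwalkers, nsteps, ndim) samples from one temperature.
    Returns the per-parameter PSRF, matching c = sqrt(Vhat / W) above."""
    M, N = float(s.shape[0]), float(s.shape[1])
    W = np.mean(np.var(s, axis=1), axis=0)      # mean within-walker variance
    per_walker_mean = np.mean(s, axis=1)        # (nwalkers, ndim)
    mean_of_means = np.mean(per_walker_mean, axis=0)
    # textbook between-walker variance (assumed; not shown in the diff)
    B = N / (M - 1) * np.sum((per_walker_mean - mean_of_means) ** 2, axis=0)
    Vhat = (N - 1) / N * W + (M + 1) / (M * N) * B
    return np.sqrt(Vhat / W)

# e.g. gelman_rubin_psrf(np.random.randn(100, 500, 3)) is close to 1
# for well-mixed, stationary walkers.
```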
```diff
@@ -956,9 +945,11 @@ class MCMCSearch(core.BaseSearchClass):
                             zorder=-10)
                 ax.plot(c_x[break_idx:], c_y[break_idx:, i], '-C0',
                         zorder=-10)
-                ax.set_ylabel('PSRF')
+                if self.convergence_test_type == 'autocorr':
+                    ax.set_ylabel(r'$\tau_\mathrm{exp}$')
+                elif self.convergence_test_type == 'GR':
+                    ax.set_ylabel('PSRF')
                 ax.ticklabel_format(useOffset=False)
                 ax.set_ylim(0.5, self.convergence_plot_upper_lim)
             else:
                 axes[0].ticklabel_format(useOffset=False, axis='y')
             cs = chain[:, :, temp].T
```
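A hypothetical call pattern for the reorganised interface is sketched below, assuming an already constructed `MCMCSearch` instance named `search` (constructor arguments omitted) and the usual `run()` entry point; the keyword values are illustrative only.

```python
# Sketch only: `search` stands for a configured MCMCSearch instance.
# Before this commit one would have called setup_convergence_testing();
# after it, the burn-in convergence test is configured as follows.
search.setup_burnin_convergence_testing(
    n=100,                 # test convergence every 100 steps
    test_type='autocorr',  # or 'GR' for the Gelman-Rubin statistic
    windowed=False,        # use the full chain so far, not a sliding window
    n_cut=5,               # passed through to test_autocorr_convergence
)
search.run()  # burn-in now stops early once i > n_cut * tau_exp
```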