Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
P
PyFstat
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Deploy
Releases
Model registry
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Gregory Ashton
PyFstat
Commits
8435f54d
Commit
8435f54d
authored
8 years ago
by
Gregory Ashton
Browse files
Options
Downloads
Patches
Plain Diff
Splits up the MCMC classes
This makes the MCMCGlitchSearch a subclass of the more general MCMCSearch
parent
1e111130
No related branches found
No related tags found
No related merge requests found
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
pyfstat.py
+199
-77
199 additions, 77 deletions
pyfstat.py
tests/tests.py
+6
-6
6 additions, 6 deletions
tests/tests.py
with
205 additions
and
83 deletions
pyfstat.py
+
199
−
77
View file @
8435f54d
...
@@ -346,15 +346,14 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat):
...
@@ -346,15 +346,14 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat):
return
twoFsegA
+
twoFsegB
return
twoFsegA
+
twoFsegB
class
MCMC
Glitch
Search
(
BaseSearchClass
):
class
MCMCSearch
(
BaseSearchClass
):
"""
MCMC search using
the SemiCoherentGlitchSearch
"""
"""
MCMC search using
ComputeFstat
"""
@initializer
@initializer
def
__init__
(
self
,
label
,
outdir
,
sftlabel
,
sftdir
,
theta_prior
,
tref
,
def
__init__
(
self
,
label
,
outdir
,
sftlabel
,
sftdir
,
theta_prior
,
tref
,
tstart
,
tend
,
nsteps
=
[
100
,
100
,
100
],
nwalkers
=
100
,
ntemps
=
1
,
tstart
,
tend
,
nsteps
=
[
100
,
100
,
100
],
nwalkers
=
100
,
ntemps
=
1
,
nglitch
=
0
,
theta_initial
=
None
,
minCoverFreq
=
None
,
theta_initial
=
None
,
minCoverFreq
=
None
,
maxCoverFreq
=
None
,
scatter_val
=
1e-4
,
betas
=
None
,
maxCoverFreq
=
None
,
scatter_val
=
1e-4
,
betas
=
None
,
detector
=
None
,
dtglitchmin
=
20
*
86400
,
earth_ephem
=
None
,
detector
=
None
,
earth_ephem
=
None
,
sun_ephem
=
None
):
sun_ephem
=
None
):
"""
"""
Parameters
Parameters
label, outdir: str
label, outdir: str
...
@@ -370,8 +369,6 @@ class MCMCGlitchSearch(BaseSearchClass):
...
@@ -370,8 +369,6 @@ class MCMCGlitchSearch(BaseSearchClass):
Either a dictionary of distribution about which to distribute the
Either a dictionary of distribution about which to distribute the
initial walkers about, an array (from which the walkers will be
initial walkers about, an array (from which the walkers will be
scattered by scatter_val, or None in which case the prior is used.
scattered by scatter_val, or None in which case the prior is used.
nglitch: int
The number of glitches to allow
tref, tstart, tend: int
tref, tstart, tend: int
GPS seconds of the reference time, start time and end time
GPS seconds of the reference time, start time and end time
nsteps: list (m,)
nsteps: list (m,)
...
@@ -379,9 +376,6 @@ class MCMCGlitchSearch(BaseSearchClass):
...
@@ -379,9 +376,6 @@ class MCMCGlitchSearch(BaseSearchClass):
give the nburn and nprod of the
'
production
'
run, all entries
give the nburn and nprod of the
'
production
'
run, all entries
before are for iterative initialisation steps (usually just one)
before are for iterative initialisation steps (usually just one)
e.g. [1000, 1000, 500].
e.g. [1000, 1000, 500].
dtglitchmin: int
The minimum duration (in seconds) of a segment between two glitches
or a glitch and the start/end of the data
nwalkers, ntemps: int
nwalkers, ntemps: int
Number of walkers and temperatures
Number of walkers and temperatures
minCoverFreq, maxCoverFreq: float
minCoverFreq, maxCoverFreq: float
...
@@ -394,12 +388,14 @@ class MCMCGlitchSearch(BaseSearchClass):
...
@@ -394,12 +388,14 @@ class MCMCGlitchSearch(BaseSearchClass):
"""
"""
logging
.
info
(
(
'
Set-up MCMC search with {} glitches for model {} on
'
logging
.
info
(
'
data {}
'
)
.
format
(
self
.
nglitch
,
self
.
label
,
'
Set-up MCMC search for model {} on
data {}
'
.
format
(
self
.
sftlabel
))
self
.
label
,
self
.
sftlabel
))
if
os
.
path
.
isdir
(
outdir
)
is
False
:
if
os
.
path
.
isdir
(
outdir
)
is
False
:
os
.
mkdir
(
outdir
)
os
.
mkdir
(
outdir
)
self
.
pickle_path
=
'
{}/{}_saved_data.p
'
.
format
(
self
.
outdir
,
self
.
label
)
self
.
pickle_path
=
'
{}/{}_saved_data.p
'
.
format
(
self
.
outdir
,
self
.
label
)
self
.
theta_prior
[
'
tstart
'
]
=
self
.
tstart
self
.
theta_prior
[
'
tend
'
]
=
self
.
tend
self
.
unpack_input_theta
()
self
.
unpack_input_theta
()
self
.
ndim
=
len
(
self
.
theta_keys
)
self
.
ndim
=
len
(
self
.
theta_keys
)
self
.
sft_filepath
=
self
.
sftdir
+
'
/*_
'
+
self
.
sftlabel
+
"
*sft
"
self
.
sft_filepath
=
self
.
sftdir
+
'
/*_
'
+
self
.
sftlabel
+
"
*sft
"
...
@@ -415,53 +411,35 @@ class MCMCGlitchSearch(BaseSearchClass):
...
@@ -415,53 +411,35 @@ class MCMCGlitchSearch(BaseSearchClass):
def
inititate_search_object
(
self
):
def
inititate_search_object
(
self
):
logging
.
info
(
'
Setting up search object
'
)
logging
.
info
(
'
Setting up search object
'
)
self
.
search
=
SemiCoherentGlitchSearch
(
self
.
search
=
ComputeFstat
(
label
=
self
.
label
,
outdir
=
self
.
outdir
,
sftlabel
=
self
.
sftlabel
,
tref
=
self
.
tref
,
sftlabel
=
self
.
sftlabel
,
sftdir
=
self
.
sftdir
,
tref
=
self
.
tref
,
tstart
=
self
.
tstart
,
sftdir
=
self
.
sftdir
,
minCoverFreq
=
self
.
minCoverFreq
,
tend
=
self
.
tend
,
minCoverFreq
=
self
.
minCoverFreq
,
maxCoverFreq
=
self
.
maxCoverFreq
,
earth_ephem
=
self
.
earth_ephem
,
maxCoverFreq
=
self
.
maxCoverFreq
,
earth_ephem
=
self
.
earth_ephem
,
sun_ephem
=
self
.
sun_ephem
,
detector
=
self
.
detector
,
sun_ephem
=
self
.
sun_ephem
,
detector
=
self
.
detector
)
nglitch
=
self
.
nglitch
)
def
logp
(
self
,
theta_vals
,
theta_prior
,
theta_keys
,
search
):
def
logp
(
self
,
theta_vals
,
theta_prior
,
theta_keys
,
search
):
if
self
.
nglitch
>
1
:
H
=
[
self
.
generic_lnprior
(
**
theta_prior
[
key
])(
p
)
for
p
,
key
in
ts
=
[
self
.
tstart
]
+
theta_vals
[
-
self
.
nglitch
:]
+
[
self
.
tend
]
if
np
.
array_equal
(
ts
,
np
.
sort
(
ts
))
is
False
:
return
-
np
.
inf
if
any
(
np
.
diff
(
ts
)
<
self
.
dtglitchmin
):
return
-
np
.
inf
H
=
[
self
.
Generic_lnprior
(
**
theta_prior
[
key
])(
p
)
for
p
,
key
in
zip
(
theta_vals
,
theta_keys
)]
zip
(
theta_vals
,
theta_keys
)]
return
np
.
sum
(
H
)
return
np
.
sum
(
H
)
def
logl
(
self
,
theta
,
search
):
def
logl
(
self
,
theta
,
search
):
for
j
,
theta_i
in
enumerate
(
self
.
theta_idxs
):
for
j
,
theta_i
in
enumerate
(
self
.
theta_idxs
):
self
.
fixed_theta
[
theta_i
]
=
theta
[
j
]
self
.
fixed_theta
[
theta_i
]
=
theta
[
j
]
FS
=
search
.
compute
_nglitch_fsta
t
(
*
self
.
fixed_theta
)
FS
=
search
.
run_
compute
fstatistic_single_poin
t
(
*
self
.
fixed_theta
)
return
FS
return
FS
def
unpack_input_theta
(
self
):
def
unpack_input_theta
(
self
):
glitch_keys
=
[
'
delta_F0
'
,
'
delta_F1
'
,
'
tglitch
'
]
full_theta_keys
=
[
'
tstart
'
,
'
tend
'
,
'
F0
'
,
'
F1
'
,
'
F2
'
,
'
Alpha
'
,
full_glitch_keys
=
list
(
np
.
array
(
'
Delta
'
]
[[
gk
]
*
self
.
nglitch
for
gk
in
glitch_keys
]).
flatten
())
full_theta_keys
=
[
'
F0
'
,
'
F1
'
,
'
F2
'
,
'
Alpha
'
,
'
Delta
'
]
+
full_glitch_keys
full_theta_keys_copy
=
copy
.
copy
(
full_theta_keys
)
full_theta_keys_copy
=
copy
.
copy
(
full_theta_keys
)
glitch_symbols
=
[
'
$\delta f$
'
,
'
$\delta \dot{f}$
'
,
r
'
$t_{glitch}$
'
]
full_theta_symbols
=
[
'
_
'
,
'
_
'
,
'
$f$
'
,
'
$\dot{f}$
'
,
'
$\ddot{f}$
'
,
full_glitch_symbols
=
list
(
np
.
array
(
r
'
$\alpha$
'
,
r
'
$\delta$
'
]
[[
gs
]
*
self
.
nglitch
for
gs
in
glitch_symbols
]).
flatten
())
full_theta_symbols
=
([
'
$f$
'
,
'
$\dot{f}$
'
,
'
$\ddot{f}$
'
,
r
'
$\alpha$
'
,
r
'
$\delta$
'
]
+
full_glitch_symbols
)
self
.
theta_keys
=
[]
self
.
theta_keys
=
[]
fixed_theta_dict
=
{}
fixed_theta_dict
=
{}
for
key
,
val
in
self
.
theta_prior
.
iteritems
():
for
key
,
val
in
self
.
theta_prior
.
iteritems
():
if
type
(
val
)
is
dict
:
if
type
(
val
)
is
dict
:
fixed_theta_dict
[
key
]
=
0
fixed_theta_dict
[
key
]
=
0
if
key
in
glitch_keys
:
for
i
in
range
(
self
.
nglitch
):
self
.
theta_keys
.
append
(
key
)
else
:
self
.
theta_keys
.
append
(
key
)
self
.
theta_keys
.
append
(
key
)
elif
type
(
val
)
in
[
float
,
int
,
np
.
float64
]:
elif
type
(
val
)
in
[
float
,
int
,
np
.
float64
]:
fixed_theta_dict
[
key
]
=
val
fixed_theta_dict
[
key
]
=
val
...
@@ -469,10 +447,6 @@ class MCMCGlitchSearch(BaseSearchClass):
...
@@ -469,10 +447,6 @@ class MCMCGlitchSearch(BaseSearchClass):
raise
ValueError
(
raise
ValueError
(
'
Type {} of {} in theta not recognised
'
.
format
(
'
Type {} of {} in theta not recognised
'
.
format
(
type
(
val
),
key
))
type
(
val
),
key
))
if
key
in
glitch_keys
:
for
i
in
range
(
self
.
nglitch
):
full_theta_keys_copy
.
pop
(
full_theta_keys_copy
.
index
(
key
))
else
:
full_theta_keys_copy
.
pop
(
full_theta_keys_copy
.
index
(
key
))
full_theta_keys_copy
.
pop
(
full_theta_keys_copy
.
index
(
key
))
if
len
(
full_theta_keys_copy
)
>
0
:
if
len
(
full_theta_keys_copy
)
>
0
:
...
@@ -489,13 +463,6 @@ class MCMCGlitchSearch(BaseSearchClass):
...
@@ -489,13 +463,6 @@ class MCMCGlitchSearch(BaseSearchClass):
self
.
theta_symbols
=
[
self
.
theta_symbols
[
i
]
for
i
in
idxs
]
self
.
theta_symbols
=
[
self
.
theta_symbols
[
i
]
for
i
in
idxs
]
self
.
theta_keys
=
[
self
.
theta_keys
[
i
]
for
i
in
idxs
]
self
.
theta_keys
=
[
self
.
theta_keys
[
i
]
for
i
in
idxs
]
# Correct for number of glitches in the idxs
self
.
theta_idxs
=
np
.
array
(
self
.
theta_idxs
)
while
np
.
sum
(
self
.
theta_idxs
[:
-
1
]
==
self
.
theta_idxs
[
1
:])
>
0
:
for
i
,
idx
in
enumerate
(
self
.
theta_idxs
):
if
idx
in
self
.
theta_idxs
[:
i
]:
self
.
theta_idxs
[
i
]
+=
1
def
check_initial_points
(
self
,
p0
):
def
check_initial_points
(
self
,
p0
):
initial_priors
=
np
.
array
([
initial_priors
=
np
.
array
([
self
.
logp
(
p
,
self
.
theta_prior
,
self
.
theta_keys
,
self
.
search
)
self
.
logp
(
p
,
self
.
theta_prior
,
self
.
theta_keys
,
self
.
search
)
...
@@ -525,7 +492,8 @@ class MCMCGlitchSearch(BaseSearchClass):
...
@@ -525,7 +492,8 @@ class MCMCGlitchSearch(BaseSearchClass):
logpargs
=
(
self
.
theta_prior
,
self
.
theta_keys
,
self
.
search
),
logpargs
=
(
self
.
theta_prior
,
self
.
theta_keys
,
self
.
search
),
loglargs
=
(
self
.
search
,),
betas
=
self
.
betas
)
loglargs
=
(
self
.
search
,),
betas
=
self
.
betas
)
p0
=
self
.
GenerateInitial
()
p0
=
self
.
generate_initial_p0
()
p0
=
self
.
apply_corrections_to_p0
(
p0
)
self
.
check_initial_points
(
p0
)
self
.
check_initial_points
(
p0
)
ninit_steps
=
len
(
self
.
nsteps
)
-
2
ninit_steps
=
len
(
self
.
nsteps
)
-
2
...
@@ -534,11 +502,12 @@ class MCMCGlitchSearch(BaseSearchClass):
...
@@ -534,11 +502,12 @@ class MCMCGlitchSearch(BaseSearchClass):
j
,
ninit_steps
,
n
))
j
,
ninit_steps
,
n
))
sampler
.
run_mcmc
(
p0
,
n
)
sampler
.
run_mcmc
(
p0
,
n
)
fig
,
axes
=
self
.
P
lot
W
alkers
(
sampler
,
symbols
=
self
.
theta_symbols
)
fig
,
axes
=
self
.
p
lot
_w
alkers
(
sampler
,
symbols
=
self
.
theta_symbols
)
fig
.
savefig
(
'
{}/{}_init_{}_walkers.png
'
.
format
(
fig
.
savefig
(
'
{}/{}_init_{}_walkers.png
'
.
format
(
self
.
outdir
,
self
.
label
,
j
))
self
.
outdir
,
self
.
label
,
j
))
p0
=
self
.
get_new_p0
(
sampler
,
scatter_val
=
self
.
scatter_val
)
p0
=
self
.
get_new_p0
(
sampler
,
scatter_val
=
self
.
scatter_val
)
p0
=
self
.
apply_corrections_to_p0
(
p0
)
self
.
check_initial_points
(
p0
)
self
.
check_initial_points
(
p0
)
sampler
.
reset
()
sampler
.
reset
()
...
@@ -548,7 +517,7 @@ class MCMCGlitchSearch(BaseSearchClass):
...
@@ -548,7 +517,7 @@ class MCMCGlitchSearch(BaseSearchClass):
nburn
+
nprod
))
nburn
+
nprod
))
sampler
.
run_mcmc
(
p0
,
nburn
+
nprod
)
sampler
.
run_mcmc
(
p0
,
nburn
+
nprod
)
fig
,
axes
=
self
.
P
lot
W
alkers
(
sampler
,
symbols
=
self
.
theta_symbols
)
fig
,
axes
=
self
.
p
lot
_w
alkers
(
sampler
,
symbols
=
self
.
theta_symbols
)
fig
.
savefig
(
'
{}/{}_walkers.png
'
.
format
(
self
.
outdir
,
self
.
label
))
fig
.
savefig
(
'
{}/{}_walkers.png
'
.
format
(
self
.
outdir
,
self
.
label
))
samples
=
sampler
.
chain
[
0
,
:,
nburn
:,
:].
reshape
((
-
1
,
self
.
ndim
))
samples
=
sampler
.
chain
[
0
,
:,
nburn
:,
:].
reshape
((
-
1
,
self
.
ndim
))
...
@@ -622,14 +591,14 @@ class MCMCGlitchSearch(BaseSearchClass):
...
@@ -622,14 +591,14 @@ class MCMCGlitchSearch(BaseSearchClass):
ax
=
axes
[
i
][
i
]
ax
=
axes
[
i
][
i
]
xlim
=
ax
.
get_xlim
()
xlim
=
ax
.
get_xlim
()
s
=
samples
[:,
i
]
s
=
samples
[:,
i
]
prior
=
self
.
G
eneric_lnprior
(
**
self
.
theta_prior
[
key
])
prior
=
self
.
g
eneric_lnprior
(
**
self
.
theta_prior
[
key
])
x
=
np
.
linspace
(
s
.
min
(),
s
.
max
(),
100
)
x
=
np
.
linspace
(
s
.
min
(),
s
.
max
(),
100
)
ax2
=
ax
.
twinx
()
ax2
=
ax
.
twinx
()
ax2
.
get_yaxis
().
set_visible
(
False
)
ax2
.
get_yaxis
().
set_visible
(
False
)
ax2
.
plot
(
x
,
[
prior
(
xi
)
for
xi
in
x
],
'
-r
'
)
ax2
.
plot
(
x
,
[
prior
(
xi
)
for
xi
in
x
],
'
-r
'
)
ax
.
set_xlim
(
xlim
)
ax
.
set_xlim
(
xlim
)
def
G
eneric_lnprior
(
self
,
**
kwargs
):
def
g
eneric_lnprior
(
self
,
**
kwargs
):
"""
Return a lambda function of the pdf
"""
Return a lambda function of the pdf
Parameters
Parameters
...
@@ -679,7 +648,7 @@ class MCMCGlitchSearch(BaseSearchClass):
...
@@ -679,7 +648,7 @@ class MCMCGlitchSearch(BaseSearchClass):
logging
.
info
(
"
kwargs:
"
,
kwargs
)
logging
.
info
(
"
kwargs:
"
,
kwargs
)
raise
ValueError
(
"
Print unrecognise distribution
"
)
raise
ValueError
(
"
Print unrecognise distribution
"
)
def
G
enerate
RV
(
self
,
**
kwargs
):
def
g
enerate
_rv
(
self
,
**
kwargs
):
dist_type
=
kwargs
.
pop
(
'
type
'
)
dist_type
=
kwargs
.
pop
(
'
type
'
)
if
dist_type
==
"
unif
"
:
if
dist_type
==
"
unif
"
:
return
np
.
random
.
uniform
(
low
=
kwargs
[
'
lower
'
],
high
=
kwargs
[
'
upper
'
])
return
np
.
random
.
uniform
(
low
=
kwargs
[
'
lower
'
],
high
=
kwargs
[
'
upper
'
])
...
@@ -694,7 +663,7 @@ class MCMCGlitchSearch(BaseSearchClass):
...
@@ -694,7 +663,7 @@ class MCMCGlitchSearch(BaseSearchClass):
else
:
else
:
raise
ValueError
(
"
dist_type {} unknown
"
.
format
(
dist_type
))
raise
ValueError
(
"
dist_type {} unknown
"
.
format
(
dist_type
))
def
P
lot
W
alkers
(
self
,
sampler
,
symbols
=
None
,
alpha
=
0.4
,
color
=
"
k
"
,
temp
=
0
,
def
p
lot
_w
alkers
(
self
,
sampler
,
symbols
=
None
,
alpha
=
0.4
,
color
=
"
k
"
,
temp
=
0
,
start
=
None
,
stop
=
None
,
draw_vline
=
None
):
start
=
None
,
stop
=
None
,
draw_vline
=
None
):
"""
Plot all the chains from a sampler
"""
"""
Plot all the chains from a sampler
"""
...
@@ -725,38 +694,35 @@ class MCMCGlitchSearch(BaseSearchClass):
...
@@ -725,38 +694,35 @@ class MCMCGlitchSearch(BaseSearchClass):
return
fig
,
axes
return
fig
,
axes
def
_generate_scattered_p0
(
self
,
p
):
def
apply_corrections_to_p0
(
self
,
p0
):
"""
Apply any correction to the initial p0 values
"""
return
p0
def
generate_scattered_p0
(
self
,
p
):
"""
Generate a set of p0s scattered about p
"""
"""
Generate a set of p0s scattered about p
"""
p0
=
[[
p
+
scatter_val
*
p
*
np
.
random
.
randn
(
self
.
ndim
)
p0
=
[[
p
+
self
.
scatter_val
*
p
*
np
.
random
.
randn
(
self
.
ndim
)
for
i
in
xrange
(
self
.
nwalkers
)]
for
i
in
xrange
(
self
.
nwalkers
)]
for
j
in
xrange
(
self
.
ntemps
)]
for
j
in
xrange
(
self
.
ntemps
)]
return
p0
return
p0
def
_sort_p0_times
(
self
,
p0
):
def
generate_initial_p0
(
self
):
p0
=
np
.
array
(
p0
)
p0
[:,
:,
-
self
.
nglitch
:]
=
np
.
sort
(
p0
[:,
:,
-
self
.
nglitch
:],
axis
=
2
)
return
p0
def
GenerateInitial
(
self
):
"""
Generate a set of init vals for the walkers
"""
"""
Generate a set of init vals for the walkers
"""
if
type
(
self
.
theta_initial
)
==
dict
:
if
type
(
self
.
theta_initial
)
==
dict
:
p0
=
[[[
self
.
G
enerate
RV
(
**
self
.
theta_initial
[
key
])
p0
=
[[[
self
.
g
enerate
_rv
(
**
self
.
theta_initial
[
key
])
for
key
in
self
.
theta_keys
]
for
key
in
self
.
theta_keys
]
for
i
in
range
(
self
.
nwalkers
)]
for
i
in
range
(
self
.
nwalkers
)]
for
j
in
range
(
self
.
ntemps
)]
for
j
in
range
(
self
.
ntemps
)]
elif
self
.
theta_initial
is
None
:
elif
self
.
theta_initial
is
None
:
p0
=
[[[
self
.
G
enerate
RV
(
**
self
.
theta_prior
[
key
])
p0
=
[[[
self
.
g
enerate
_rv
(
**
self
.
theta_prior
[
key
])
for
key
in
self
.
theta_keys
]
for
key
in
self
.
theta_keys
]
for
i
in
range
(
self
.
nwalkers
)]
for
i
in
range
(
self
.
nwalkers
)]
for
j
in
range
(
self
.
ntemps
)]
for
j
in
range
(
self
.
ntemps
)]
elif
len
(
self
.
theta_initial
)
==
self
.
ndim
:
elif
len
(
self
.
theta_initial
)
==
self
.
ndim
:
p0
=
self
.
_
generate_scattered_p0
(
self
.
theta_initial
)
p0
=
self
.
generate_scattered_p0
(
self
.
theta_initial
)
else
:
else
:
raise
ValueError
(
'
theta_initial not understood
'
)
raise
ValueError
(
'
theta_initial not understood
'
)
if
self
.
nglitch
>
1
:
p0
=
self
.
_sort_p0_times
(
p0
)
return
p0
return
p0
def
get_new_p0
(
self
,
sampler
,
scatter_val
=
1e-3
):
def
get_new_p0
(
self
,
sampler
,
scatter_val
=
1e-3
):
...
@@ -780,8 +746,6 @@ class MCMCGlitchSearch(BaseSearchClass):
...
@@ -780,8 +746,6 @@ class MCMCGlitchSearch(BaseSearchClass):
p
=
pF
[
np
.
nanargmax
(
lnp
)]
p
=
pF
[
np
.
nanargmax
(
lnp
)]
p0
=
self
.
_generate_scattered_p0
(
p
)
p0
=
self
.
_generate_scattered_p0
(
p
)
if
self
.
nglitch
>
1
:
p0
=
self
.
_sort_p0_times
(
p0
)
return
p0
return
p0
def
get_save_data_dictionary
(
self
):
def
get_save_data_dictionary
(
self
):
...
@@ -923,6 +887,164 @@ class MCMCGlitchSearch(BaseSearchClass):
...
@@ -923,6 +887,164 @@ class MCMCGlitchSearch(BaseSearchClass):
k
,
d
[
k
],
d
[
k
+
'
_std
'
]))
k
,
d
[
k
],
d
[
k
+
'
_std
'
]))
class
MCMCGlitchSearch
(
MCMCSearch
):
"""
MCMC search using the SemiCoherentGlitchSearch
"""
@initializer
def
__init__
(
self
,
label
,
outdir
,
sftlabel
,
sftdir
,
theta_prior
,
tref
,
tstart
,
tend
,
nsteps
=
[
100
,
100
,
100
],
nwalkers
=
100
,
ntemps
=
1
,
nglitch
=
0
,
theta_initial
=
None
,
minCoverFreq
=
None
,
maxCoverFreq
=
None
,
scatter_val
=
1e-4
,
betas
=
None
,
detector
=
None
,
dtglitchmin
=
20
*
86400
,
earth_ephem
=
None
,
sun_ephem
=
None
):
"""
Parameters
label, outdir: str
A label and directory to read/write data from/to
sftlabel, sftdir: str
A label and directory in which to find the relevant sft file
theta_prior: dict
Dictionary of priors and fixed values for the search parameters.
For each parameters (key of the dict), if it is to be held fixed
the value should be the constant float, if it is be searched, the
value should be a dictionary of the prior.
theta_initial: dict, array, (None)
Either a dictionary of distribution about which to distribute the
initial walkers about, an array (from which the walkers will be
scattered by scatter_val, or None in which case the prior is used.
nglitch: int
The number of glitches to allow
tref, tstart, tend: int
GPS seconds of the reference time, start time and end time
nsteps: list (m,)
List specifying the number of steps to take, the last two entries
give the nburn and nprod of the
'
production
'
run, all entries
before are for iterative initialisation steps (usually just one)
e.g. [1000, 1000, 500].
dtglitchmin: int
The minimum duration (in seconds) of a segment between two glitches
or a glitch and the start/end of the data
nwalkers, ntemps: int
Number of walkers and temperatures
minCoverFreq, maxCoverFreq: float
Minimum and maximum instantaneous frequency which will be covered
over the SFT time span as passed to CreateFstatInput
earth_ephem, sun_ephem: str
Paths of the two files containing positions of Earth and Sun,
respectively at evenly spaced times, as passed to CreateFstatInput
If None defaults defined in BaseSearchClass will be used
"""
logging
.
info
((
'
Set-up MCMC glitch search with {} glitches for model {}
'
'
on data {}
'
).
format
(
self
.
nglitch
,
self
.
label
,
self
.
sftlabel
))
if
os
.
path
.
isdir
(
outdir
)
is
False
:
os
.
mkdir
(
outdir
)
self
.
pickle_path
=
'
{}/{}_saved_data.p
'
.
format
(
self
.
outdir
,
self
.
label
)
self
.
unpack_input_theta
()
self
.
ndim
=
len
(
self
.
theta_keys
)
self
.
sft_filepath
=
self
.
sftdir
+
'
/*_
'
+
self
.
sftlabel
+
"
*sft
"
if
earth_ephem
is
None
:
self
.
earth_ephem
=
self
.
earth_ephem_default
if
sun_ephem
is
None
:
self
.
sun_ephem
=
self
.
sun_ephem_default
if
args
.
clean
and
os
.
path
.
isfile
(
self
.
pickle_path
):
os
.
rename
(
self
.
pickle_path
,
self
.
pickle_path
+
"
.old
"
)
self
.
old_data_is_okay_to_use
=
self
.
check_old_data_is_okay_to_use
()
def
inititate_search_object
(
self
):
logging
.
info
(
'
Setting up search object
'
)
self
.
search
=
SemiCoherentGlitchSearch
(
label
=
self
.
label
,
outdir
=
self
.
outdir
,
sftlabel
=
self
.
sftlabel
,
sftdir
=
self
.
sftdir
,
tref
=
self
.
tref
,
tstart
=
self
.
tstart
,
tend
=
self
.
tend
,
minCoverFreq
=
self
.
minCoverFreq
,
maxCoverFreq
=
self
.
maxCoverFreq
,
earth_ephem
=
self
.
earth_ephem
,
sun_ephem
=
self
.
sun_ephem
,
detector
=
self
.
detector
,
nglitch
=
self
.
nglitch
)
def
logp
(
self
,
theta_vals
,
theta_prior
,
theta_keys
,
search
):
if
self
.
nglitch
>
1
:
ts
=
[
self
.
tstart
]
+
theta_vals
[
-
self
.
nglitch
:]
+
[
self
.
tend
]
if
np
.
array_equal
(
ts
,
np
.
sort
(
ts
))
is
False
:
return
-
np
.
inf
if
any
(
np
.
diff
(
ts
)
<
self
.
dtglitchmin
):
return
-
np
.
inf
H
=
[
self
.
generic_lnprior
(
**
theta_prior
[
key
])(
p
)
for
p
,
key
in
zip
(
theta_vals
,
theta_keys
)]
return
np
.
sum
(
H
)
def
logl
(
self
,
theta
,
search
):
for
j
,
theta_i
in
enumerate
(
self
.
theta_idxs
):
self
.
fixed_theta
[
theta_i
]
=
theta
[
j
]
FS
=
search
.
compute_nglitch_fstat
(
*
self
.
fixed_theta
)
return
FS
def
unpack_input_theta
(
self
):
glitch_keys
=
[
'
delta_F0
'
,
'
delta_F1
'
,
'
tglitch
'
]
full_glitch_keys
=
list
(
np
.
array
(
[[
gk
]
*
self
.
nglitch
for
gk
in
glitch_keys
]).
flatten
())
full_theta_keys
=
[
'
F0
'
,
'
F1
'
,
'
F2
'
,
'
Alpha
'
,
'
Delta
'
]
+
full_glitch_keys
full_theta_keys_copy
=
copy
.
copy
(
full_theta_keys
)
glitch_symbols
=
[
'
$\delta f$
'
,
'
$\delta \dot{f}$
'
,
r
'
$t_{glitch}$
'
]
full_glitch_symbols
=
list
(
np
.
array
(
[[
gs
]
*
self
.
nglitch
for
gs
in
glitch_symbols
]).
flatten
())
full_theta_symbols
=
([
'
$f$
'
,
'
$\dot{f}$
'
,
'
$\ddot{f}$
'
,
r
'
$\alpha$
'
,
r
'
$\delta$
'
]
+
full_glitch_symbols
)
self
.
theta_keys
=
[]
fixed_theta_dict
=
{}
for
key
,
val
in
self
.
theta_prior
.
iteritems
():
if
type
(
val
)
is
dict
:
fixed_theta_dict
[
key
]
=
0
if
key
in
glitch_keys
:
for
i
in
range
(
self
.
nglitch
):
self
.
theta_keys
.
append
(
key
)
else
:
self
.
theta_keys
.
append
(
key
)
elif
type
(
val
)
in
[
float
,
int
,
np
.
float64
]:
fixed_theta_dict
[
key
]
=
val
else
:
raise
ValueError
(
'
Type {} of {} in theta not recognised
'
.
format
(
type
(
val
),
key
))
if
key
in
glitch_keys
:
for
i
in
range
(
self
.
nglitch
):
full_theta_keys_copy
.
pop
(
full_theta_keys_copy
.
index
(
key
))
else
:
full_theta_keys_copy
.
pop
(
full_theta_keys_copy
.
index
(
key
))
if
len
(
full_theta_keys_copy
)
>
0
:
raise
ValueError
((
'
Input dictionary `theta` is missing the
'
'
following keys: {}
'
).
format
(
full_theta_keys_copy
))
self
.
fixed_theta
=
[
fixed_theta_dict
[
key
]
for
key
in
full_theta_keys
]
self
.
theta_idxs
=
[
full_theta_keys
.
index
(
k
)
for
k
in
self
.
theta_keys
]
self
.
theta_symbols
=
[
full_theta_symbols
[
i
]
for
i
in
self
.
theta_idxs
]
idxs
=
np
.
argsort
(
self
.
theta_idxs
)
self
.
theta_idxs
=
[
self
.
theta_idxs
[
i
]
for
i
in
idxs
]
self
.
theta_symbols
=
[
self
.
theta_symbols
[
i
]
for
i
in
idxs
]
self
.
theta_keys
=
[
self
.
theta_keys
[
i
]
for
i
in
idxs
]
# Correct for number of glitches in the idxs
self
.
theta_idxs
=
np
.
array
(
self
.
theta_idxs
)
while
np
.
sum
(
self
.
theta_idxs
[:
-
1
]
==
self
.
theta_idxs
[
1
:])
>
0
:
for
i
,
idx
in
enumerate
(
self
.
theta_idxs
):
if
idx
in
self
.
theta_idxs
[:
i
]:
self
.
theta_idxs
[
i
]
+=
1
def
apply_corrections_to_p0
(
self
,
p0
):
p0
=
np
.
array
(
p0
)
if
self
.
nglitch
>
1
:
p0
[:,
:,
-
self
.
nglitch
:]
=
np
.
sort
(
p0
[:,
:,
-
self
.
nglitch
:],
axis
=
2
)
return
p0
class
GridGlitchSearch
(
BaseSearchClass
):
class
GridGlitchSearch
(
BaseSearchClass
):
"""
Gridded search using the SemiCoherentGlitchSearch
"""
"""
Gridded search using the SemiCoherentGlitchSearch
"""
@initializer
@initializer
...
...
This diff is collapsed.
Click to expand it.
tests/tests.py
+
6
−
6
View file @
8435f54d
...
@@ -137,7 +137,7 @@ class TestSemiCoherentGlitchSearch(unittest.TestCase):
...
@@ -137,7 +137,7 @@ class TestSemiCoherentGlitchSearch(unittest.TestCase):
self
.
assertTrue
(
np
.
abs
((
FS
-
predicted_FS
))
/
predicted_FS
<
0.3
)
self
.
assertTrue
(
np
.
abs
((
FS
-
predicted_FS
))
/
predicted_FS
<
0.3
)
class
TestMCMC
Glitch
Search
(
unittest
.
TestCase
):
class
TestMCMCSearch
(
unittest
.
TestCase
):
label
=
"
MCMCTest
"
label
=
"
MCMCTest
"
outdir
=
'
TestData
'
outdir
=
'
TestData
'
...
@@ -165,13 +165,12 @@ class TestMCMCGlitchSearch(unittest.TestCase):
...
@@ -165,13 +165,12 @@ class TestMCMCGlitchSearch(unittest.TestCase):
Writer
.
make_data
()
Writer
.
make_data
()
predicted_FS
=
Writer
.
predict_fstat
()
predicted_FS
=
Writer
.
predict_fstat
()
theta
=
{
'
delta_F0
'
:
0
,
'
delta_F1
'
:
0
,
'
tglitch
'
:
tend
,
theta
=
{
'
F0
'
:
{
'
type
'
:
'
norm
'
,
'
loc
'
:
F0
,
'
scale
'
:
np
.
abs
(
1e-9
*
F0
)},
'
F0
'
:
{
'
type
'
:
'
norm
'
,
'
loc
'
:
F0
,
'
scale
'
:
np
.
abs
(
1e-9
*
F0
)},
'
F1
'
:
{
'
type
'
:
'
norm
'
,
'
loc
'
:
F1
,
'
scale
'
:
np
.
abs
(
1e-9
*
F1
)},
'
F1
'
:
{
'
type
'
:
'
norm
'
,
'
loc
'
:
F1
,
'
scale
'
:
np
.
abs
(
1e-9
*
F1
)},
'
F2
'
:
F2
,
'
Alpha
'
:
Alpha
,
'
Delta
'
:
Delta
}
'
F2
'
:
F2
,
'
Alpha
'
:
Alpha
,
'
Delta
'
:
Delta
}
search
=
pyfstat
.
MCMC
Glitch
Search
(
search
=
pyfstat
.
MCMCSearch
(
label
=
self
.
label
,
outdir
=
self
.
outdir
,
theta
=
theta
,
tref
=
tref
,
label
=
self
.
label
,
outdir
=
self
.
outdir
,
theta
_prior
=
theta
,
tref
=
tref
,
sftlabel
=
self
.
label
,
sftdir
=
self
.
outdir
,
sftlabel
=
self
.
label
,
sftdir
=
self
.
outdir
,
tstart
=
tstart
,
tend
=
tend
,
nsteps
=
[
100
,
100
],
nwalkers
=
100
,
tstart
=
tstart
,
tend
=
tend
,
nsteps
=
[
100
,
100
],
nwalkers
=
100
,
ntemps
=
1
)
ntemps
=
1
)
...
@@ -181,7 +180,8 @@ class TestMCMCGlitchSearch(unittest.TestCase):
...
@@ -181,7 +180,8 @@ class TestMCMCGlitchSearch(unittest.TestCase):
print
(
'
Predicted twoF is {} while recovered is {}
'
.
format
(
print
(
'
Predicted twoF is {} while recovered is {}
'
.
format
(
predicted_FS
,
FS
))
predicted_FS
,
FS
))
self
.
assertTrue
(
np
.
abs
((
FS
-
predicted_FS
))
/
predicted_FS
<
0.3
)
self
.
assertTrue
(
FS
>
predicted_FS
or
np
.
abs
((
FS
-
predicted_FS
))
/
predicted_FS
<
0.3
)
if
__name__
==
'
__main__
'
:
if
__name__
==
'
__main__
'
:
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment