Skip to content
Snippets Groups Projects
Commit 88856d78 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Adds MCMCFollowUpSearch

This adds a class to perform follow ups from semi-coherent to
fully-coherent using an MCMC search.

- Also changes the logic of the MCMC search to check for old data only
  in the run method.
- The nsteps argument exists for the MCMCFOllowUP, but is unused.
- Some work on the plot_walkers to allow for overplotting. It may be
  useful to use this for the usual calls as well so that we don't have
  an initial and walkers plot?
parent 17f9dffc
Branches
Tags
No related merge requests found
docs/img/follow_up_walkers.png

104 KiB

import pyfstat
F0 = 30.0
F1 = -1e-10
F2 = 0
Alpha = 5e-3
Delta = 6e-2
tref = 362750407.0
tstart = 1000000000
duration = 100*86400
tend = tstart + duration
theta_prior = {'F0': {'type': 'unif', 'lower': F0*(1-1e-6), 'upper': F0*(1+1e-5)},
'F1': {'type': 'unif', 'lower': F1*(1+1e-2), 'upper': F1*(1-1e-2)},
'F2': F2,
'Alpha': Alpha,
'Delta': Delta
}
ntemps = 1
log10temperature_min = -1
nwalkers = 100
run_setup = [(500, 50), (500, 25), (100, 1, False),
((100, 100), 1, True)]
mcmc = pyfstat.MCMCFollowUpSearch(
label='follow_up', outdir='data',
sftfilepath='data/*basic*sft', theta_prior=theta_prior, tref=tref,
minStartTime=tstart, maxStartTime=tend, nwalkers=nwalkers,
ntemps=ntemps, log10temperature_min=log10temperature_min)
mcmc.run(run_setup)
mcmc.plot_corner(add_prior=True)
mcmc.print_summary()
......@@ -477,15 +477,68 @@ class SemiCoherentSearch(BaseSearchClass, ComputeFstat):
self.tboundaries = np.linspace(self.minStartTime, self.maxStartTime,
self.nsegs+1)
def compute_nseg_fstat(self, F0, F1, F2, Alpha, Delta):
""" Returns the semi-coherent summed twoF """
def run_semi_coherent_computefstatistic_single_point(
self, F0, F1, F2, Alpha, Delta, asini=None,
period=None, ecc=None, tp=None, argp=None):
""" Returns twoF or ln(BSGL) semi-coherently at a single point """
twoFvals = [self.run_computefstatistic_single_point(
self.tboundaries[i], self.tboundaries[i+1], F0, F1, F2, Alpha,
Delta)
for i in range(self.nsegs)]
self.PulsarDopplerParams.fkdot = np.array([F0, F1, F2, 0, 0, 0, 0])
self.PulsarDopplerParams.Alpha = Alpha
self.PulsarDopplerParams.Delta = Delta
if self.binary:
self.PulsarDopplerParams.asini = asini
self.PulsarDopplerParams.period = period
self.PulsarDopplerParams.ecc = ecc
self.PulsarDopplerParams.tp = tp
self.PulsarDopplerParams.argp = argp
lalpulsar.ComputeFstat(self.FstatResults,
self.FstatInput,
self.PulsarDopplerParams,
1,
self.whatToCompute
)
if self.transient is False:
if self.BSGL is False:
return self.FstatResults.twoF[0]
twoF = np.float(self.FstatResults.twoF[0])
self.twoFX[0] = self.FstatResults.twoFPerDet(0)
self.twoFX[1] = self.FstatResults.twoFPerDet(1)
log10_BSGL = lalpulsar.ComputeBSGL(twoF, self.twoFX,
self.BSGLSetup)
return log10_BSGL/np.log10(np.exp(1))
return np.sum(twoFvals)
detStat = 0
for tstart, tend in zip(self.tboundaries[:-1], self.tboundaries[1:]):
self.windowRange.t0 = int(tstart) # TYPE UINT4
self.windowRange.tau = int(tend - tstart) # TYPE UINT4
FS = lalpulsar.ComputeTransientFstatMap(
self.FstatResults.multiFatoms[0], self.windowRange, False)
if self.BSGL is False:
detStat += 2*FS.F_mn.data[0][0]
continue
FstatResults_single = copy.copy(self.FstatResults)
FstatResults_single.lenth = 1
FstatResults_single.data = self.FstatResults.multiFatoms[0].data[0]
FS0 = lalpulsar.ComputeTransientFstatMap(
FstatResults_single.multiFatoms[0], self.windowRange, False)
FstatResults_single.data = self.FstatResults.multiFatoms[0].data[1]
FS1 = lalpulsar.ComputeTransientFstatMap(
FstatResults_single.multiFatoms[0], self.windowRange, False)
self.twoFX[0] = 2*FS0.F_mn.data[0][0]
self.twoFX[1] = 2*FS1.F_mn.data[0][0]
log10_BSGL = lalpulsar.ComputeBSGL(
2*FS.F_mn.data[0][0], self.twoFX, self.BSGLSetup)
detStat += log10_BSGL/np.log10(np.exp(1))
return detStat
class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat):
......@@ -669,7 +722,6 @@ class MCMCSearch(BaseSearchClass):
if args.clean and os.path.isfile(self.pickle_path):
os.rename(self.pickle_path, self.pickle_path+".old")
self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
self.log_input()
def log_input(self):
......@@ -789,6 +841,7 @@ class MCMCSearch(BaseSearchClass):
def run(self, proposal_scale_factor=2):
self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
if self.old_data_is_okay_to_use is True:
logging.warning('Using saved data from {}'.format(
self.pickle_path))
......@@ -813,7 +866,7 @@ class MCMCSearch(BaseSearchClass):
ninit_steps = len(self.nsteps) - 2
for j, n in enumerate(self.nsteps[:-2]):
logging.info('Running {}/{} initialisation with {} steps'.format(
j+1, ninit_steps, n))
j, ninit_steps, n))
sampler = self.run_sampler_with_progress_bar(sampler, n, p0)
logging.info("Mean acceptance fraction: {}"
.format(np.mean(sampler.acceptance_fraction, axis=1)))
......@@ -1083,7 +1136,8 @@ class MCMCSearch(BaseSearchClass):
raise ValueError("dist_type {} unknown".format(dist_type))
def plot_walkers(self, sampler, symbols=None, alpha=0.4, color="k", temp=0,
lw=0.1, burnin_idx=None, add_det_stat_burnin=False):
lw=0.1, burnin_idx=None, add_det_stat_burnin=False,
fig=None, axes=None, xoffset=0, plot_det_stat=True):
""" Plot all the chains from a sampler """
shape = sampler.chain.shape
......@@ -1100,6 +1154,7 @@ class MCMCSearch(BaseSearchClass):
chain = sampler.chain[temp, :, :, :]
with plt.style.context(('classic')):
if fig is None and axes is None:
fig = plt.figure(figsize=(8, 4*ndim))
ax = fig.add_subplot(ndim+1, 1, 1)
axes = [ax] + [fig.add_subplot(ndim+1, 1, i, sharex=ax)
......@@ -1111,10 +1166,11 @@ class MCMCSearch(BaseSearchClass):
axes[i].ticklabel_format(useOffset=False, axis='y')
cs = chain[:, :, i].T
if burnin_idx:
axes[i].plot(idxs[:burnin_idx], cs[:burnin_idx],
color="r", alpha=alpha, lw=lw)
axes[i].plot(idxs[burnin_idx:], cs[burnin_idx:], color="k",
alpha=alpha, lw=lw)
axes[i].plot(xoffset+idxs[:burnin_idx],
cs[:burnin_idx], color="r", alpha=alpha,
lw=lw)
axes[i].plot(xoffset+idxs[burnin_idx:], cs[burnin_idx:],
color="k", alpha=alpha, lw=lw)
if symbols:
axes[i].set_ylabel(symbols[i])
else:
......@@ -1128,19 +1184,30 @@ class MCMCSearch(BaseSearchClass):
if symbols:
axes[0].set_ylabel(symbols[0])
if len(axes) == ndim:
axes.append(fig.add_subplot(ndim+1, 1, ndim+1))
if plot_det_stat:
lnl = sampler.lnlikelihood[temp, :, :]
if burnin_idx and add_det_stat_burnin:
vals = lnl[:, :burnin_idx].flatten()
axes[-1].hist(vals[~np.isnan(vals)], bins=50, histtype='step',
color='r')
vals = lnl[:, burnin_idx:].flatten()
axes[-1].hist(vals[~np.isnan(vals)], bins=50, histtype='step',
color='k')
burn_in_vals = lnl[:, :burnin_idx].flatten()
axes[-1].hist(burn_in_vals[~np.isnan(burn_in_vals)], bins=50,
histtype='step', color='r')
else:
burn_in_vals = []
prod_vals = lnl[:, burnin_idx:].flatten()
axes[-1].hist(prod_vals[~np.isnan(prod_vals)], bins=50,
histtype='step', color='k')
if self.BSGL:
axes[-1].set_xlabel(r'$\mathcal{B}_\mathrm{S/GL}$')
else:
axes[-1].set_xlabel(r'$2\mathcal{F}$')
combined_vals = np.append(burn_in_vals, prod_vals)
if len(combined_vals) > 0:
minv = np.min(combined_vals)
maxv = np.max(combined_vals)
Range = abs(maxv-minv)
axes[-1].set_xlim(minv-0.1*Range, maxv+0.1*Range)
return fig, axes
......@@ -1461,7 +1528,6 @@ class MCMCSearch(BaseSearchClass):
def compute_evidence(self):
""" Computes the evidence/marginal likelihood for the model """
fburnin = float(self.nsteps[-2])/np.sum(self.nsteps[-2:])
print fburnin
lnev, lnev_err = self.sampler.thermodynamic_integration_log_evidence(
fburnin=fburnin)
......@@ -1794,7 +1860,6 @@ class MCMCSemiCoherentSearch(MCMCSearch):
if args.clean and os.path.isfile(self.pickle_path):
os.rename(self.pickle_path, self.pickle_path+".old")
self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
self.log_input()
def inititate_search_object(self):
......@@ -1815,10 +1880,107 @@ class MCMCSemiCoherentSearch(MCMCSearch):
def logl(self, theta, search):
for j, theta_i in enumerate(self.theta_idxs):
self.fixed_theta[theta_i] = theta[j]
FS = search.compute_nseg_fstat(*self.fixed_theta)
FS = search.run_semi_coherent_computefstatistic_single_point(
*self.fixed_theta)
return FS
class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
""" A follow up procudure increasing the coherence time in a zoom """
def get_save_data_dictionary(self):
d = dict(nwalkers=self.nwalkers, ntemps=self.ntemps,
theta_keys=self.theta_keys, theta_prior=self.theta_prior,
scatter_val=self.scatter_val,
log10temperature_min=self.log10temperature_min,
BSGL=self.BSGL, run_setup=self.run_setup)
return d
def run(self, run_setup, proposal_scale_factor=2):
""" Run the follow-up with the given run_setup
Parameters
----------
run_setup: list of tuples
"""
logging.info('Using run-setup as follow:')
logging.info('Stage | nburn | nprod | nsegs | resetp0')
for i, rs in enumerate(run_setup):
rs = list(rs)
if len(rs) == 2:
rs.append(False)
if np.shape(rs[0]) == ():
rs[0] = (rs[0], 0)
run_setup[i] = rs
logging.info('{} | {} | {} | {} | {}'.format(
str(i).ljust(5), str(rs[0][0]).ljust(5),
str(rs[0][1]).ljust(5), str(rs[1]).ljust(5), rs[2]))
self.run_setup = run_setup
self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
if self.old_data_is_okay_to_use is True:
logging.warning('Using saved data from {}'.format(
self.pickle_path))
d = self.get_saved_data()
self.sampler = d['sampler']
self.samples = d['samples']
self.lnprobs = d['lnprobs']
self.lnlikes = d['lnlikes']
return
fig = None
axes = None
nsteps_total = 0
for j, ((nburn, nprod), nseg, reset_p0) in enumerate(run_setup):
if j == 0:
p0 = self.generate_initial_p0()
p0 = self.apply_corrections_to_p0(p0)
elif reset_p0:
p0 = self.get_new_p0(sampler)
p0 = self.apply_corrections_to_p0(p0)
# self.check_initial_points(p0)
else:
p0 = sampler.chain[:, :, -1, :]
self.nsegs = nseg
self.inititate_search_object()
sampler = emcee.PTSampler(
self.ntemps, self.nwalkers, self.ndim, self.logl, self.logp,
logpargs=(self.theta_prior, self.theta_keys, self.search),
loglargs=(self.search,), betas=self.betas,
a=proposal_scale_factor)
logging.info('Running {}/{} with {} steps and {} nsegs'.format(
j+1, len(self.nsteps), (nburn, nprod), nseg))
sampler = self.run_sampler_with_progress_bar(
sampler, nburn+nprod, p0)
logging.info("Mean acceptance fraction: {}"
.format(np.mean(sampler.acceptance_fraction, axis=1)))
if self.ntemps > 1:
logging.info("Tswap acceptance fraction: {}"
.format(sampler.tswap_acceptance_fraction))
fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols,
fig=fig, axes=axes, burnin_idx=nburn,
xoffset=nsteps_total)
for ax in axes[:-1]:
ax.axvline(nsteps_total, color='k', ls='--')
nsteps_total += nburn+nprod
fig.savefig('{}/{}_walkers.png'.format(
self.outdir, self.label))
samples = sampler.chain[0, :, nburn:, :].reshape((-1, self.ndim))
lnprobs = sampler.lnprobability[0, :, nburn:].reshape((-1))
lnlikes = sampler.lnlikelihood[0, :, nburn:].reshape((-1))
self.sampler = sampler
self.samples = samples
self.lnprobs = lnprobs
self.lnlikes = lnlikes
self.save_data(sampler, samples, lnprobs, lnlikes)
class MCMCTransientSearch(MCMCSearch):
""" MCMC search for a transient signal using the ComputeFstat """
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment