diff --git a/docs/img/follow_up_walkers.png b/docs/img/follow_up_walkers.png
new file mode 100644
index 0000000000000000000000000000000000000000..70643f6882927a50fd7e527bec389ee51417c6f0
Binary files /dev/null and b/docs/img/follow_up_walkers.png differ
diff --git a/examples/follow_up.py b/examples/follow_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..d76dc6d61a868e450aca2b4dc813487bca981a1a
--- /dev/null
+++ b/examples/follow_up.py
@@ -0,0 +1,34 @@
+import pyfstat
+
+F0 = 30.0
+F1 = -1e-10
+F2 = 0
+Alpha = 5e-3
+Delta = 6e-2
+tref = 362750407.0
+
+tstart = 1000000000
+duration = 100*86400
+tend = tstart + duration
+
+theta_prior = {'F0': {'type': 'unif', 'lower': F0*(1-1e-6), 'upper': F0*(1+1e-5)},
+               'F1': {'type': 'unif', 'lower': F1*(1+1e-2), 'upper': F1*(1-1e-2)},
+               'F2': F2,
+               'Alpha': Alpha,
+               'Delta': Delta
+               }
+
+ntemps = 1
+log10temperature_min = -1
+nwalkers = 100
+run_setup = [(500, 50), (500, 25), (100, 1, False),
+             ((100, 100), 1, True)]
+
+mcmc = pyfstat.MCMCFollowUpSearch(
+    label='follow_up', outdir='data',
+    sftfilepath='data/*basic*sft', theta_prior=theta_prior, tref=tref,
+    minStartTime=tstart, maxStartTime=tend, nwalkers=nwalkers,
+    ntemps=ntemps, log10temperature_min=log10temperature_min)
+mcmc.run(run_setup)
+mcmc.plot_corner(add_prior=True)
+mcmc.print_summary()
diff --git a/pyfstat.py b/pyfstat.py
index 0fba2e3918c942ac36ec706f2154ef98156c4093..d1fb7bf21e6d6c2822b65800d094afcea2ae1d6a 100755
--- a/pyfstat.py
+++ b/pyfstat.py
@@ -477,15 +477,68 @@ class SemiCoherentSearch(BaseSearchClass, ComputeFstat):
         self.tboundaries = np.linspace(self.minStartTime, self.maxStartTime,
                                        self.nsegs+1)
 
-    def compute_nseg_fstat(self, F0, F1, F2, Alpha, Delta):
-        """ Returns the semi-coherent summed twoF """
+    def run_semi_coherent_computefstatistic_single_point(
+            self, F0, F1, F2, Alpha, Delta, asini=None,
+            period=None, ecc=None, tp=None, argp=None):
+        """ Returns twoF or ln(BSGL) semi-coherently at a single point """
 
-        twoFvals = [self.run_computefstatistic_single_point(
-            self.tboundaries[i], self.tboundaries[i+1], F0, F1, F2, Alpha,
-            Delta)
-            for i in range(self.nsegs)]
+        self.PulsarDopplerParams.fkdot = np.array([F0, F1, F2, 0, 0, 0, 0])
+        self.PulsarDopplerParams.Alpha = Alpha
+        self.PulsarDopplerParams.Delta = Delta
+        if self.binary:
+            self.PulsarDopplerParams.asini = asini
+            self.PulsarDopplerParams.period = period
+            self.PulsarDopplerParams.ecc = ecc
+            self.PulsarDopplerParams.tp = tp
+            self.PulsarDopplerParams.argp = argp
+
+        lalpulsar.ComputeFstat(self.FstatResults,
+                               self.FstatInput,
+                               self.PulsarDopplerParams,
+                               1,
+                               self.whatToCompute
+                               )
+
+        if self.transient is False:
+            if self.BSGL is False:
+                return self.FstatResults.twoF[0]
+
+            twoF = np.float(self.FstatResults.twoF[0])
+            self.twoFX[0] = self.FstatResults.twoFPerDet(0)
+            self.twoFX[1] = self.FstatResults.twoFPerDet(1)
+            log10_BSGL = lalpulsar.ComputeBSGL(twoF, self.twoFX,
+                                               self.BSGLSetup)
+            return log10_BSGL/np.log10(np.exp(1))
+
+        detStat = 0
+        for tstart, tend in zip(self.tboundaries[:-1], self.tboundaries[1:]):
+            self.windowRange.t0 = int(tstart)  # TYPE UINT4
+            self.windowRange.tau = int(tend - tstart)  # TYPE UINT4
+
+            FS = lalpulsar.ComputeTransientFstatMap(
+                self.FstatResults.multiFatoms[0], self.windowRange, False)
+
+            if self.BSGL is False:
+                detStat += 2*FS.F_mn.data[0][0]
+                continue
 
-        return np.sum(twoFvals)
+            FstatResults_single = copy.copy(self.FstatResults)
+            FstatResults_single.lenth = 1
+            FstatResults_single.data = self.FstatResults.multiFatoms[0].data[0]
+            FS0 = lalpulsar.ComputeTransientFstatMap(
+                FstatResults_single.multiFatoms[0], self.windowRange, False)
+            FstatResults_single.data = self.FstatResults.multiFatoms[0].data[1]
+            FS1 = lalpulsar.ComputeTransientFstatMap(
+                FstatResults_single.multiFatoms[0], self.windowRange, False)
+
+            self.twoFX[0] = 2*FS0.F_mn.data[0][0]
+            self.twoFX[1] = 2*FS1.F_mn.data[0][0]
+            log10_BSGL = lalpulsar.ComputeBSGL(
+                2*FS.F_mn.data[0][0], self.twoFX, self.BSGLSetup)
+
+            detStat += log10_BSGL/np.log10(np.exp(1))
+
+        return detStat
 
 
 class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat):
@@ -669,7 +722,6 @@ class MCMCSearch(BaseSearchClass):
         if args.clean and os.path.isfile(self.pickle_path):
             os.rename(self.pickle_path, self.pickle_path+".old")
 
-        self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
         self.log_input()
 
     def log_input(self):
@@ -789,6 +841,7 @@ class MCMCSearch(BaseSearchClass):
 
     def run(self, proposal_scale_factor=2):
 
+        self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
         if self.old_data_is_okay_to_use is True:
             logging.warning('Using saved data from {}'.format(
                 self.pickle_path))
@@ -813,7 +866,7 @@ class MCMCSearch(BaseSearchClass):
         ninit_steps = len(self.nsteps) - 2
         for j, n in enumerate(self.nsteps[:-2]):
             logging.info('Running {}/{} initialisation with {} steps'.format(
-                j+1, ninit_steps, n))
+                j, ninit_steps, n))
             sampler = self.run_sampler_with_progress_bar(sampler, n, p0)
             logging.info("Mean acceptance fraction: {}"
                          .format(np.mean(sampler.acceptance_fraction, axis=1)))
@@ -1083,7 +1136,8 @@ class MCMCSearch(BaseSearchClass):
             raise ValueError("dist_type {} unknown".format(dist_type))
 
     def plot_walkers(self, sampler, symbols=None, alpha=0.4, color="k", temp=0,
-                     lw=0.1, burnin_idx=None, add_det_stat_burnin=False):
+                     lw=0.1, burnin_idx=None, add_det_stat_burnin=False,
+                     fig=None, axes=None, xoffset=0, plot_det_stat=True):
         """ Plot all the chains from a sampler """
 
         shape = sampler.chain.shape
@@ -1100,10 +1154,11 @@ class MCMCSearch(BaseSearchClass):
             chain = sampler.chain[temp, :, :, :]
 
         with plt.style.context(('classic')):
-            fig = plt.figure(figsize=(8, 4*ndim))
-            ax = fig.add_subplot(ndim+1, 1, 1)
-            axes = [ax] + [fig.add_subplot(ndim+1, 1, i, sharex=ax)
-                           for i in range(2, ndim+1)]
+            if fig is None and axes is None:
+                fig = plt.figure(figsize=(8, 4*ndim))
+                ax = fig.add_subplot(ndim+1, 1, 1)
+                axes = [ax] + [fig.add_subplot(ndim+1, 1, i, sharex=ax)
+                               for i in range(2, ndim+1)]
 
             idxs = np.arange(chain.shape[1])
             if ndim > 1:
@@ -1111,10 +1166,11 @@ class MCMCSearch(BaseSearchClass):
                     axes[i].ticklabel_format(useOffset=False, axis='y')
                     cs = chain[:, :, i].T
                     if burnin_idx:
-                        axes[i].plot(idxs[:burnin_idx], cs[:burnin_idx],
-                                     color="r", alpha=alpha, lw=lw)
-                    axes[i].plot(idxs[burnin_idx:], cs[burnin_idx:], color="k",
-                                 alpha=alpha, lw=lw)
+                        axes[i].plot(xoffset+idxs[:burnin_idx],
+                                     cs[:burnin_idx], color="r", alpha=alpha,
+                                     lw=lw)
+                    axes[i].plot(xoffset+idxs[burnin_idx:], cs[burnin_idx:],
+                                 color="k", alpha=alpha, lw=lw)
                     if symbols:
                         axes[i].set_ylabel(symbols[i])
             else:
@@ -1128,19 +1184,30 @@ class MCMCSearch(BaseSearchClass):
                 if symbols:
                     axes[0].set_ylabel(symbols[0])
 
-            axes.append(fig.add_subplot(ndim+1, 1, ndim+1))
-            lnl = sampler.lnlikelihood[temp, :, :]
-            if burnin_idx and add_det_stat_burnin:
-                vals = lnl[:, :burnin_idx].flatten()
-                axes[-1].hist(vals[~np.isnan(vals)], bins=50, histtype='step',
-                              color='r')
-            vals = lnl[:, burnin_idx:].flatten()
-            axes[-1].hist(vals[~np.isnan(vals)], bins=50, histtype='step',
-                          color='k')
-            if self.BSGL:
-                axes[-1].set_xlabel(r'$\mathcal{B}_\mathrm{S/GL}$')
-            else:
-                axes[-1].set_xlabel(r'$2\mathcal{F}$')
+            if len(axes) == ndim:
+                axes.append(fig.add_subplot(ndim+1, 1, ndim+1))
+
+            if plot_det_stat:
+                lnl = sampler.lnlikelihood[temp, :, :]
+                if burnin_idx and add_det_stat_burnin:
+                    burn_in_vals = lnl[:, :burnin_idx].flatten()
+                    axes[-1].hist(burn_in_vals[~np.isnan(burn_in_vals)], bins=50,
+                                  histtype='step', color='r')
+                else:
+                    burn_in_vals = []
+                prod_vals = lnl[:, burnin_idx:].flatten()
+                axes[-1].hist(prod_vals[~np.isnan(prod_vals)], bins=50,
+                              histtype='step', color='k')
+                if self.BSGL:
+                    axes[-1].set_xlabel(r'$\mathcal{B}_\mathrm{S/GL}$')
+                else:
+                    axes[-1].set_xlabel(r'$2\mathcal{F}$')
+                combined_vals = np.append(burn_in_vals, prod_vals)
+                if len(combined_vals) > 0:
+                    minv = np.nanmin(combined_vals)
+                    maxv = np.nanmax(combined_vals)
+                    Range = abs(maxv-minv)
+                    axes[-1].set_xlim(minv-0.1*Range, maxv+0.1*Range)
 
         return fig, axes
 
@@ -1461,7 +1528,6 @@ class MCMCSearch(BaseSearchClass):
     def compute_evidence(self):
         """ Computes the evidence/marginal likelihood for the model """
         fburnin = float(self.nsteps[-2])/np.sum(self.nsteps[-2:])
-        print fburnin
         lnev, lnev_err = self.sampler.thermodynamic_integration_log_evidence(
             fburnin=fburnin)
 
@@ -1794,7 +1860,6 @@ class MCMCSemiCoherentSearch(MCMCSearch):
         if args.clean and os.path.isfile(self.pickle_path):
             os.rename(self.pickle_path, self.pickle_path+".old")
 
-        self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
         self.log_input()
 
     def inititate_search_object(self):
@@ -1815,10 +1880,115 @@ class MCMCSemiCoherentSearch(MCMCSearch):
     def logl(self, theta, search):
         for j, theta_i in enumerate(self.theta_idxs):
             self.fixed_theta[theta_i] = theta[j]
-        FS = search.compute_nseg_fstat(*self.fixed_theta)
+        FS = search.run_semi_coherent_computefstatistic_single_point(
+            *self.fixed_theta)
         return FS
 
 
+class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
+    """ A follow up procedure increasing the coherence time in a zoom """
+    def get_save_data_dictionary(self):
+        """ Returns the run settings, including run_setup, as a dict """
+        d = dict(nwalkers=self.nwalkers, ntemps=self.ntemps,
+                 theta_keys=self.theta_keys, theta_prior=self.theta_prior,
+                 scatter_val=self.scatter_val,
+                 log10temperature_min=self.log10temperature_min,
+                 BSGL=self.BSGL, run_setup=self.run_setup)
+        return d
+
+    def run(self, run_setup, proposal_scale_factor=2):
+        """ Run the follow-up with the given run_setup
+
+        Parameters
+        ----------
+        run_setup: list of tuples
+            One entry per stage: `(nsteps, nsegs)` or
+            `(nsteps, nsegs, reset_p0)`; `reset_p0` defaults to False.
+            `nsteps` is a `(nburn, nprod)` pair, or a single number that
+            is treated as `(nburn, 0)`. Entries are normalised in place
+            to `[(nburn, nprod), nsegs, reset_p0]`.
+        proposal_scale_factor: number
+            Passed to `emcee.PTSampler` as its `a` argument.
+
+        """
+
+        logging.info('Using run-setup as follow:')
+        logging.info('Stage | nburn | nprod | nsegs | resetp0')
+        for i, rs in enumerate(run_setup):
+            rs = list(rs)
+            if len(rs) == 2:
+                rs.append(False)
+            if np.shape(rs[0]) == ():
+                rs[0] = (rs[0], 0)
+            run_setup[i] = rs
+            logging.info('{} | {} | {} | {} | {}'.format(
+                str(i).ljust(5), str(rs[0][0]).ljust(5),
+                str(rs[0][1]).ljust(5), str(rs[1]).ljust(5), rs[2]))
+
+        self.run_setup = run_setup
+        self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
+        if self.old_data_is_okay_to_use is True:
+            logging.warning('Using saved data from {}'.format(
+                self.pickle_path))
+            d = self.get_saved_data()
+            self.sampler = d['sampler']
+            self.samples = d['samples']
+            self.lnprobs = d['lnprobs']
+            self.lnlikes = d['lnlikes']
+            return
+
+        fig = None
+        axes = None
+        nsteps_total = 0
+        for j, ((nburn, nprod), nseg, reset_p0) in enumerate(run_setup):
+            if j == 0:
+                p0 = self.generate_initial_p0()
+                p0 = self.apply_corrections_to_p0(p0)
+            elif reset_p0:
+                p0 = self.get_new_p0(sampler)
+                p0 = self.apply_corrections_to_p0(p0)
+                # self.check_initial_points(p0)
+            else:
+                p0 = sampler.chain[:, :, -1, :]
+
+            self.nsegs = nseg
+            self.inititate_search_object()
+            sampler = emcee.PTSampler(
+                self.ntemps, self.nwalkers, self.ndim, self.logl, self.logp,
+                logpargs=(self.theta_prior, self.theta_keys, self.search),
+                loglargs=(self.search,), betas=self.betas,
+                a=proposal_scale_factor)
+
+            logging.info('Running {}/{} with {} steps and {} nsegs'.format(
+                j+1, len(run_setup), (nburn, nprod), nseg))
+            sampler = self.run_sampler_with_progress_bar(
+                sampler, nburn+nprod, p0)
+            logging.info("Mean acceptance fraction: {}"
+                         .format(np.mean(sampler.acceptance_fraction, axis=1)))
+            if self.ntemps > 1:
+                logging.info("Tswap acceptance fraction: {}"
+                             .format(sampler.tswap_acceptance_fraction))
+
+            fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols,
+                                          fig=fig, axes=axes, burnin_idx=nburn,
+                                          xoffset=nsteps_total)
+            for ax in axes[:-1]:
+                ax.axvline(nsteps_total, color='k', ls='--')
+            nsteps_total += nburn+nprod
+
+        fig.savefig('{}/{}_walkers.png'.format(
+            self.outdir, self.label))
+
+        samples = sampler.chain[0, :, nburn:, :].reshape((-1, self.ndim))
+        lnprobs = sampler.lnprobability[0, :, nburn:].reshape((-1))
+        lnlikes = sampler.lnlikelihood[0, :, nburn:].reshape((-1))
+        self.sampler = sampler
+        self.samples = samples
+        self.lnprobs = lnprobs
+        self.lnlikes = lnlikes
+        self.save_data(sampler, samples, lnprobs, lnlikes)
+
+
 class MCMCTransientSearch(MCMCSearch):
     """ MCMC search for a transient signal using the ComputeFstat """
 