From b277b235e94de7235e30c4468c1a25c4c103a28a Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Fri, 17 Nov 2017 10:12:34 +0100
Subject: [PATCH] Move from emcee to ptemcee

Previously, we used emcee.PTSampler, but this has been removed from the
master branch of emcee (see the discussion at
https://github.com/dfm/emcee/issues/236) and a fork, ptemcee, has been
developed in its place. Testing shows that ptemcee is equivalent (after
some changes to the interface) and perhaps better, as it includes the
adaptive temperature ladders developed by Will Vousden (the maintainer
of ptemcee).
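
As a rough illustration of the interface changes, a minimal sketch
(toy log-likelihood and log-prior for illustration only, not the
PyFstat call sites; it assumes ptemcee keeps the emcee-style
sample() generator):

    import numpy as np
    from ptemcee import Sampler as PTSampler

    def logl(x):
        # toy log-likelihood: standard Gaussian
        return -0.5 * np.sum(x ** 2)

    def logp(x):
        # toy log-prior: flat on [-10, 10] in each dimension
        return 0.0 if np.all(np.abs(x) < 10) else -np.inf

    ntemps, nwalkers, ndim = 4, 20, 2

    # Old: emcee.PTSampler(ntemps, nwalkers, ndim, logl, logp, ...)
    # New: keyword arguments, with `dim` in place of the positional ndim
    sampler = PTSampler(ntemps=ntemps, nwalkers=nwalkers, dim=ndim,
                        logl=logl, logp=logp, a=2)

    p0 = np.random.uniform(-1, 1, (ntemps, nwalkers, ndim))
    for _ in sampler.sample(p0, iterations=100):
        pass

    # Attribute renames: lnlikelihood -> loglikelihood,
    #                    lnprobability -> logprobability
    samples = sampler.chain[0, :, :, :].reshape((-1, ndim))
    lnlikes = sampler.loglikelihood[0, :, :].reshape((-1))

    # Autocorrelation time: get_autocorr_time(c=...) becomes
    # get_autocorr_time(window=...), with no AutocorrError to catch
    tau = sampler.get_autocorr_time(window=50)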
---
 README.md                      |  2 +-
 pyfstat/mcmc_based_searches.py | 67 +++++++++++++++++-----------------
 requirements.txt               |  2 +-
 3 files changed, 35 insertions(+), 36 deletions(-)

diff --git a/README.md b/README.md
index 1de211c..7394094 100644
--- a/README.md
+++ b/README.md
@@ -46,7 +46,7 @@ $ git clone https://gitlab.aei.uni-hannover.de/GregAshton/PyFstat.git
 * [numpy](http://www.numpy.org/)
 * [matplotlib](http://matplotlib.org/) >= 1.4
 * [scipy](https://www.scipy.org/)
-* [emcee](http://dan.iel.fm/emcee/current/)
+* [ptemcee](https://github.com/willvousden/ptemcee)
 * [corner](https://pypi.python.org/pypi/corner/)
 * [dill](https://pypi.python.org/pypi/dill)
 * [peakutils](https://pypi.python.org/pypi/PeakUtils)
diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py
index e113948..3613994 100644
--- a/pyfstat/mcmc_based_searches.py
+++ b/pyfstat/mcmc_based_searches.py
@@ -11,7 +11,7 @@ import subprocess
 import numpy as np
 import matplotlib
 import matplotlib.pyplot as plt
-import emcee
+from ptemcee import Sampler as PTSampler
 import corner
 import dill as pickle
 
@@ -54,7 +54,7 @@ class MCMCSearch(core.BaseSearchClass):
     log10beta_min float < 0, optional
         The  log_10(beta) value, if given the set of betas passed to PTSampler
         are generated from `np.logspace(0, log10beta_min, ntemps)` (given
-        in descending order to emcee).
+        in descending order to ptemcee).
     theta_initial: dict, array, optional
         A dictionary of distribution about which to distribute the
         initial walkers about
@@ -396,14 +396,9 @@ class MCMCSearch(core.BaseSearchClass):
             self.tswap_acceptance_fraction = sampler.tswap_acceptance_fraction
             logging.info("Tswap acceptance fraction: {}"
                          .format(sampler.tswap_acceptance_fraction))
-        try:
-            self.autocorr_time = sampler.get_autocorr_time(c=4)
-            logging.info("Autocorrelation length: {}".format(
-                self.autocorr_time))
-        except emcee.autocorr.AutocorrError as e:
-            self.autocorr_time = np.nan
-            logging.warning(
-                'Autocorrelation calculation failed with message {}'.format(e))
+        self.autocorr_time = sampler.get_autocorr_time(window=window)
+        logging.info("Autocorrelation length: {}".format(
+            self.autocorr_time))
 
         return sampler
 
@@ -430,7 +425,8 @@ class MCMCSearch(core.BaseSearchClass):
         logging.info('Estimated run-time = {} s = {:1.0f}:{:1.0f} m'.format(
             a+b, *divmod(a+b, 60)))
 
-    def run(self, proposal_scale_factor=2, create_plots=True, c=5, **kwargs):
+    def run(self, proposal_scale_factor=2, create_plots=True, window=50,
+            **kwargs):
         """ Run the MCMC simulatation
 
         Parameters
@@ -442,17 +438,17 @@ class MCMCSearch(core.BaseSearchClass):
             it by increasing the a parameter [Foreman-Mackay (2013)].
         create_plots: bool
             If true, save trace plots of the walkers
-        c: int
+        window: int
             The minimum number of autocorrelation times needed to trust the
             result when estimating the autocorrelation time (see
-            emcee.autocorr.integrated_time for further details. Default is 5
+            ptemcee.Sampler.get_autocorr_time for further details). Default is 50.
         **kwargs:
             Passed to _plot_walkers to control the figures
 
         Returns
         -------
-        sampler: emcee.ptsampler.PTSampler
-            The emcee ptsampler object
+        sampler: ptemcee.Sampler
+            The ptemcee Sampler object
 
         """
 
@@ -470,8 +466,9 @@ class MCMCSearch(core.BaseSearchClass):
         self._initiate_search_object()
         self._estimate_run_time()
 
-        sampler = emcee.PTSampler(
-            self.ntemps, self.nwalkers, self.ndim, self.logl, self.logp,
+        sampler = PTSampler(
+            ntemps=self.ntemps, nwalkers=self.nwalkers, dim=self.ndim,
+            logl=self.logl, logp=self.logp,
             logpargs=(self.theta_prior, self.theta_keys, self.search),
             loglargs=(self.search,), betas=self.betas, a=proposal_scale_factor)
 
@@ -484,7 +481,7 @@ class MCMCSearch(core.BaseSearchClass):
         for j, n in enumerate(self.nsteps[:-2]):
             logging.info('Running {}/{} initialisation with {} steps'.format(
                 j, ninit_steps, n))
-            sampler = self._run_sampler(sampler, p0, nburn=n)
+            sampler = self._run_sampler(sampler, p0, nburn=n, window=window)
             if create_plots:
                 fig, axes = self._plot_walkers(sampler,
                                                symbols=self.theta_symbols,
@@ -514,9 +511,9 @@ class MCMCSearch(core.BaseSearchClass):
                         )
 
         samples = sampler.chain[0, :, nburn:, :].reshape((-1, self.ndim))
-        lnprobs = sampler.lnprobability[0, :, nburn:].reshape((-1))
-        lnlikes = sampler.lnlikelihood[0, :, nburn:].reshape((-1))
-        all_lnlikelihood = sampler.lnlikelihood[:, :, nburn:]
+        lnprobs = sampler.logprobability[0, :, nburn:].reshape((-1))
+        lnlikes = sampler.loglikelihood[0, :, nburn:].reshape((-1))
+        all_lnlikelihood = sampler.loglikelihood[:, :, nburn:]
         self.samples = samples
         self.lnprobs = lnprobs
         self.lnlikes = lnlikes
@@ -1053,7 +1050,7 @@ class MCMCSearch(core.BaseSearchClass):
                 if len(axes) == ndim:
                     axes.append(fig.add_subplot(ndim+1, 1, ndim+1))
 
-                lnl = sampler.lnlikelihood[temp, :, :]
+                lnl = sampler.loglikelihood[temp, :, :]
                 if burnin_idx and add_det_stat_burnin:
                     burn_in_vals = lnl[:, :burnin_idx].flatten()
                     try:
@@ -1135,8 +1132,8 @@ class MCMCSearch(core.BaseSearchClass):
         """
         temp_idx = 0
         pF = sampler.chain[temp_idx, :, :, :]
-        lnl = sampler.lnlikelihood[temp_idx, :, :]
-        lnp = sampler.lnprobability[temp_idx, :, :]
+        lnl = sampler.loglikelihood[temp_idx, :, :]
+        lnp = sampler.logprobability[temp_idx, :, :]
 
         # General warnings about the state of lnp
         if np.any(np.isnan(lnp)):
@@ -2027,7 +2024,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
 
     def run(self, run_setup=None, proposal_scale_factor=2, NstarMax=10,
             Nsegs0=None, create_plots=True, log_table=True, gen_tex_table=True,
-            fig=None, axes=None, return_fig=False, **kwargs):
+            fig=None, axes=None, return_fig=False, window=50, **kwargs):
         """ Run the follow-up with the given run_setup
 
         Parameters
@@ -2040,10 +2037,10 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
             it by increasing the a parameter [Foreman-Mackay (2013)].
         create_plots: bool
             If true, save trace plots of the walkers
-        c: int
+        window: int
             The minimum number of autocorrelation times needed to trust the
             result when estimating the autocorrelation time (see
-            emcee.autocorr.integrated_time for further details. Default is 5
+            ptemcee.Sampler.get_autocorr_time for further details). Default is 50.
         **kwargs:
             Passed to _plot_walkers to control the figures
 
@@ -2085,8 +2082,9 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
             self.search.nsegs = nseg
             self.update_search_object()
             self.search.init_semicoherent_parameters()
-            sampler = emcee.PTSampler(
-                self.ntemps, self.nwalkers, self.ndim, self.logl, self.logp,
+            sampler = PTSampler(
+                ntemps=self.ntemps, nwalkers=self.nwalkers, dim=self.ndim,
+                logl=self.logl, logp=self.logp,
                 logpargs=(self.theta_prior, self.theta_keys, self.search),
                 loglargs=(self.search,), betas=self.betas,
                 a=proposal_scale_factor)
@@ -2095,9 +2093,10 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
             logging.info(('Running {}/{} with {} steps and {} nsegs '
                           '(Tcoh={:1.2f} days)').format(
                 j+1, len(run_setup), (nburn, nprod), nseg, Tcoh))
-            sampler = self._run_sampler(sampler, p0, nburn=nburn, nprod=nprod)
+            sampler = self._run_sampler(sampler, p0, nburn=nburn, nprod=nprod,
+                                        window=window)
             logging.info('Max detection statistic of run was {}'.format(
-                np.max(sampler.lnlikelihood)))
+                np.max(sampler.loglikelihood)))
 
             if create_plots:
                 fig, axes = self._plot_walkers(
@@ -2109,9 +2108,9 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
             nsteps_total += nburn+nprod
 
         samples = sampler.chain[0, :, nburn:, :].reshape((-1, self.ndim))
-        lnprobs = sampler.lnprobability[0, :, nburn:].reshape((-1))
-        lnlikes = sampler.lnlikelihood[0, :, nburn:].reshape((-1))
-        all_lnlikelihood = sampler.lnlikelihood
+        lnprobs = sampler.logprobability[0, :, nburn:].reshape((-1))
+        lnlikes = sampler.loglikelihood[0, :, nburn:].reshape((-1))
+        all_lnlikelihood = sampler.loglikelihood
         self.samples = samples
         self.lnprobs = lnprobs
         self.lnlikes = lnlikes
diff --git a/requirements.txt b/requirements.txt
index 3ebebbd..eee946d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 numpy
 matplotlib>=1.4
 scipy
-emcee
+ptemcee
 corner
 dill
 tqdm
-- 
GitLab