From 629667a37ddd334375f5f6afe546b0d8a1fe1cd3 Mon Sep 17 00:00:00 2001
From: "gregory.ashton" <gregory.ashton@ligo.org>
Date: Tue, 21 Feb 2017 15:41:32 +0100
Subject: [PATCH] Fix multi-stage set-up when using convergence testing (adds
 test)

---
 pyfstat/mcmc_based_searches.py |  2 +-
 tests.py                       | 15 +++++++++++++++
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py
index cf548a9..dacfbe0 100644
--- a/pyfstat/mcmc_based_searches.py
+++ b/pyfstat/mcmc_based_searches.py
@@ -638,7 +638,7 @@ class MCMCSearch(BaseSearchClass):
             raise ValueError("dist_type {} unknown".format(dist_type))
 
     def plot_walkers(self, sampler, symbols=None, alpha=0.4, color="k", temp=0,
-                     lw=0.1, nprod=None, add_det_stat_burnin=False,
+                     lw=0.1, nprod=0, add_det_stat_burnin=False,
                      fig=None, axes=None, xoffset=0, plot_det_stat=True,
                      context='classic', subtractions=None, labelpad=0.05):
         """ Plot all the chains from a sampler """
diff --git a/tests.py b/tests.py
index 64d1965..69e3320 100644
--- a/tests.py
+++ b/tests.py
@@ -232,6 +232,21 @@ class TestMCMCSearch(Test):
         self.assertTrue(
             FS > predicted_FS or np.abs((FS-predicted_FS))/predicted_FS < 0.3)
 
+    def test_multi_stage(self):
+        Writer = pyfstat.Writer()
+        Writer.make_cff()
+
+        theta = {'F0': {'type': 'norm', 'loc': 10, 'scale': 1e-2},
+                 'F1': 0, 'F2': 0, 'Alpha': 0, 'Delta': 0}
+
+        search = pyfstat.MCMCSearch(
+            label=self.label, outdir=outdir, theta_prior=theta,
+            tref=Writer.tref, injectSources=Writer.config_file_name,
+            minStartTime=Writer.minStartTime, maxStartTime=Writer.maxStartTime,
+            nsteps=[5, 5], nwalkers=20, ntemps=1, detectors='H1',
+            minCoverFreq=9, maxCoverFreq=11)
+        search.run(create_plots=False)
+
 
 class TestAuxillaryFunctions(Test):
     nsegs = 10
-- 
GitLab