From 2a5f652a67dac9c8bb69732527226184de58f00f Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Mon, 2 Oct 2017 10:05:20 +0200
Subject: [PATCH] Improvements to tests

---
 tests.py | 79 +++++++++++++++++++++++++++++++-------------------------
 1 file changed, 44 insertions(+), 35 deletions(-)

diff --git a/tests.py b/tests.py
index 4ece9d9..a27ec82 100644
--- a/tests.py
+++ b/tests.py
@@ -6,32 +6,36 @@ import pyfstat
 
 
 class Test(unittest.TestCase):
+    outdir = 'TestData'
+
     @classmethod
-    def setUpClass(cls):
-        pass
+    def setUpClass(self):
+        if os.path.isdir(self.outdir):
+            shutil.rmtree(self.outdir)
 
     @classmethod
-    def tearDownClass(cls):
-        pass
+    def tearDownClass(self):
+        if os.path.isdir(self.outdir):
+            shutil.rmtree(self.outdir)
 
 
 class TestWriter(Test):
     label = "TestWriter"
 
     def test_make_cff(self):
-        Writer = pyfstat.Writer(self.label, outdir=outdir)
+        Writer = pyfstat.Writer(self.label, outdir=self.outdir)
         Writer.make_cff()
         self.assertTrue(os.path.isfile('./TestData/{}.cff'.format(self.label)))
 
     def test_run_makefakedata(self):
-        Writer = pyfstat.Writer(self.label, outdir=outdir, duration=86400)
+        Writer = pyfstat.Writer(self.label, outdir=self.outdir, duration=86400)
         Writer.make_cff()
         Writer.run_makefakedata()
         self.assertTrue(os.path.isfile(
             './TestData/H-48_H1_1800SFT_TestWriter-700000000-86400.sft'))
 
     def test_makefakedata_usecached(self):
-        Writer = pyfstat.Writer(self.label, outdir=outdir, duration=86400)
+        Writer = pyfstat.Writer(self.label, outdir=self.outdir, duration=86400)
         if os.path.isfile(Writer.sftfilepath):
             os.remove(Writer.sftfilepath)
         Writer.make_cff()
@@ -46,6 +50,32 @@ class TestWriter(Test):
         self.assertFalse(time_first == time_third)
 
 
+class TestBunch(Test):
+    def test_bunch(self):
+        b = pyfstat.core.Bunch(dict(x=10))
+        self.assertTrue(b.x == 10)
+
+
+class TestPar(Test):
+    label = 'TestPar'
+
+    def test(self):
+        os.system('mkdir {}'.format(self.outdir))
+        os.system(
+            'echo "x=100\ny=10" > {}/{}.par'.format(self.outdir, self.label))
+
+        par = pyfstat.core.read_par(
+            '{}/{}.par'.format(self.outdir, self.label), return_type='Bunch')
+        self.assertTrue(par.x == 100)
+        self.assertTrue(par.y == 10)
+
+        par = pyfstat.core.read_par(outdir=self.outdir, label=self.label,
+                                    return_type='dict')
+        self.assertTrue(par['x'] == 100)
+        self.assertTrue(par['y'] == 10)
+        os.system('rm -r TestData')
+
+
 class TestBaseSearchClass(Test):
     def test_shift_matrix(self):
         BSC = pyfstat.BaseSearchClass()
@@ -89,7 +119,7 @@ class TestComputeFstat(Test):
     label = "TestComputeFstat"
 
     def test_run_computefstatistic_single_point(self):
-        Writer = pyfstat.Writer(self.label, outdir=outdir, duration=86400,
+        Writer = pyfstat.Writer(self.label, outdir=self.outdir, duration=86400,
                                 h0=1, sqrtSX=1)
         Writer.make_data()
         predicted_FS = Writer.predict_fstat()
@@ -104,7 +134,7 @@ class TestComputeFstat(Test):
         self.assertTrue(np.abs(predicted_FS-FS)/FS < 0.2)
 
     def run_computefstatistic_single_point_no_noise(self):
-        Writer = pyfstat.Writer(self.label, outdir=outdir, add_noise=False,
+        Writer = pyfstat.Writer(self.label, outdir=self.outdir, add_noise=False,
                                 duration=86400, h0=1, sqrtSX=1)
         Writer.make_data()
         predicted_FS = Writer.predict_fstat()
@@ -119,7 +149,7 @@ class TestComputeFstat(Test):
         self.assertTrue(np.abs(predicted_FS-FS)/FS < 0.2)
 
     def test_injectSources_from_file(self):
-        Writer = pyfstat.Writer(self.label, outdir=outdir, add_noise=False,
+        Writer = pyfstat.Writer(self.label, outdir=self.outdir, add_noise=False,
                                 duration=86400, h0=1, sqrtSX=1)
         Writer.make_cff()
         injectSources = Writer.config_file_name
@@ -148,13 +178,13 @@ class TestSemiCoherentGlitchSearch(Test):
         h0 = 1
         sqrtSX = 1
         Writer = pyfstat.GlitchWriter(
-            self.label, outdir=outdir, duration=duration, dtglitch=dtglitch,
+            self.label, outdir=self.outdir, duration=duration, dtglitch=dtglitch,
             delta_F0=delta_F0, sqrtSX=sqrtSX, h0=h0)
 
         Writer.make_data()
 
         search = pyfstat.SemiCoherentGlitchSearch(
-            label=self.label, outdir=outdir,
+            label=self.label, outdir=self.outdir,
             sftfilepattern='{}/*{}*sft'.format(Writer.outdir, Writer.label),
             tref=Writer.tref, minStartTime=Writer.tstart,
             maxStartTime=Writer.tend, nglitch=1)
@@ -198,7 +228,7 @@ class TestMCMCSearch(Test):
         tref = minStartTime
         Writer = pyfstat.Writer(F0=F0, F1=F1, F2=F2, label=self.label,
                                 h0=h0, sqrtSX=sqrtSX,
-                                outdir=outdir, tstart=minStartTime,
+                                outdir=self.outdir, tstart=minStartTime,
                                 Alpha=Alpha, Delta=Delta, tref=tref,
                                 duration=duration,
                                 Band=4)
@@ -211,7 +241,7 @@ class TestMCMCSearch(Test):
                  'F2': F2, 'Alpha': Alpha, 'Delta': Delta}
 
         search = pyfstat.MCMCSearch(
-            label=self.label, outdir=outdir, theta_prior=theta, tref=tref,
+            label=self.label, outdir=self.outdir, theta_prior=theta, tref=tref,
             sftfilepattern='{}/*{}*sft'.format(Writer.outdir, Writer.label),
             minStartTime=minStartTime, maxStartTime=maxStartTime,
             nsteps=[100, 100], nwalkers=100, ntemps=2, log10beta_min=-1)
@@ -224,27 +254,6 @@ class TestMCMCSearch(Test):
         self.assertTrue(
             FS > predicted_FS or np.abs((FS-predicted_FS))/predicted_FS < 0.3)
 
-    def test_multi_stage(self):
-        Writer = pyfstat.Writer(F0=10, duration=86400, h0=1, sqrtSX=1)
-        Writer.make_cff()
-
-        theta = {'F0': {'type': 'norm', 'loc': 10, 'scale': 1e-2},
-                 'F1': 0, 'F2': 0, 'Alpha': 0, 'Delta': 0}
-
-        search = pyfstat.MCMCSearch(
-            label=self.label, outdir=outdir, theta_prior=theta,
-            tref=Writer.tref, injectSources=Writer.config_file_name,
-            minStartTime=Writer.minStartTime, maxStartTime=Writer.maxStartTime,
-            nsteps=[5, 5], nwalkers=20, ntemps=1, detectors='H1',
-            minCoverFreq=9, maxCoverFreq=11)
-        search.run(create_plots=False)
-
 
 if __name__ == '__main__':
-    outdir = 'TestData'
-    if os.path.isdir(outdir):
-        shutil.rmtree(outdir)
     unittest.main()
-    if os.path.isdir(outdir):
-        shutil.rmtree(outdir)
-
-- 
GitLab