diff --git a/pyfstat.py b/pyfstat.py index 18368b16baf1fbfe94934fc0c429955b2092e954..d84d6ac1fcf5ce36652bb9d286c65c82d547a356 100755 --- a/pyfstat.py +++ b/pyfstat.py @@ -828,7 +828,8 @@ class MCMCSearch(BaseSearchClass): def get_save_data_dictionary(self): d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers, ntemps=self.ntemps, theta_keys=self.theta_keys, - theta_prior=self.theta_prior, scatter_val=self.scatter_val) + theta_prior=self.theta_prior, scatter_val=self.scatter_val, + log10temperature_min=self.log10temperature_min) return d def save_data(self, sampler, samples, lnprobs, lnlikes): diff --git a/tests.py b/tests.py index 176b75f844dd252ba21d4eceb5cea46720f38183..1a9ccf65af4a20af3d4fa453ad1ce94ee5aee264 100644 --- a/tests.py +++ b/tests.py @@ -1,28 +1,37 @@ import unittest -import pyfstat import numpy as np import os +import shutil +import pyfstat + +class Test(unittest.TestCase): + @classmethod + def setUpClass(cls): + pass -class TestWriter(unittest.TestCase): + @classmethod + def tearDownClass(cls): + pass + + +class TestWriter(Test): + label = "Test" def test_make_cff(self): - label = "Test" - Writer = pyfstat.Writer(label, outdir='TestData') + Writer = pyfstat.Writer(self.label, outdir=outdir) Writer.make_cff() self.assertTrue(os.path.isfile('./TestData/Test.cff')) def test_run_makefakedata(self): - label = "Test" - Writer = pyfstat.Writer(label, outdir='TestData') + Writer = pyfstat.Writer(self.label, outdir=outdir) Writer.make_cff() Writer.run_makefakedata() self.assertTrue(os.path.isfile( './TestData/H-4800_H1_1800SFT_Test-700000000-8640000.sft')) def test_makefakedata_usecached(self): - label = "Test" - Writer = pyfstat.Writer(label, outdir='TestData') + Writer = pyfstat.Writer(self.label, outdir=outdir) if os.path.isfile(Writer.sft_filepath): os.remove(Writer.sft_filepath) Writer.run_makefakedata() @@ -36,7 +45,7 @@ class TestWriter(unittest.TestCase): self.assertFalse(time_first == time_third) -class TestBaseSearchClass(unittest.TestCase): +class TestBaseSearchClass(Test): def test_shift_matrix(self): BSC = pyfstat.BaseSearchClass() dT = 10 @@ -75,12 +84,11 @@ class TestBaseSearchClass(unittest.TestCase): rtol=1e-9, atol=1e-9)) -class TestComputeFstat(unittest.TestCase): +class TestComputeFstat(Test): label = "Test" - outdir = 'TestData' def test_run_computefstatistic_single_point(self): - Writer = pyfstat.Writer(self.label, outdir=self.outdir) + Writer = pyfstat.Writer(self.label, outdir=outdir) Writer.make_data() predicted_FS = Writer.predict_fstat() @@ -97,15 +105,14 @@ class TestComputeFstat(unittest.TestCase): self.assertTrue(np.abs(predicted_FS-FS)/FS < 0.1) -class TestSemiCoherentGlitchSearch(unittest.TestCase): +class TestSemiCoherentGlitchSearch(Test): label = "Test" - outdir = 'TestData' def test_compute_nglitch_fstat(self): duration = 100*86400 dtglitch = 100*43200 delta_F0 = 0 - Writer = pyfstat.Writer(self.label, outdir=self.outdir, + Writer = pyfstat.Writer(self.label, outdir=outdir, duration=duration, dtglitch=dtglitch, delta_F0=delta_F0) @@ -137,9 +144,8 @@ class TestSemiCoherentGlitchSearch(unittest.TestCase): self.assertTrue(np.abs((FS - predicted_FS))/predicted_FS < 0.3) -class TestMCMCSearch(unittest.TestCase): - label = "MCMCTest" - outdir = 'TestData' +class TestMCMCSearch(Test): + label = "Test" def test_fully_coherent(self): h0 = 1e-24 @@ -157,7 +163,7 @@ class TestMCMCSearch(unittest.TestCase): delta_F0 = 0 Writer = pyfstat.Writer(F0=F0, F1=F1, F2=F2, label=self.label, h0=h0, sqrtSX=sqrtSX, - outdir=self.outdir, tstart=tstart, + outdir=outdir, tstart=tstart, Alpha=Alpha, Delta=Delta, tref=tref, duration=duration, dtglitch=dtglitch, delta_F0=delta_F0, Band=4) @@ -170,8 +176,8 @@ class TestMCMCSearch(unittest.TestCase): 'F2': F2, 'Alpha': Alpha, 'Delta': Delta} search = pyfstat.MCMCSearch( - label=self.label, outdir=self.outdir, theta_prior=theta, tref=tref, - sftlabel=self.label, sftdir=self.outdir, + label=self.label, outdir=outdir, theta_prior=theta, tref=tref, + sftlabel=self.label, sftdir=outdir, tstart=tstart, tend=tend, nsteps=[100, 100], nwalkers=100, ntemps=1) search.run() @@ -185,4 +191,10 @@ class TestMCMCSearch(unittest.TestCase): if __name__ == '__main__': + outdir = 'TestData' + if os.path.isdir(outdir): + shutil.rmtree(outdir) unittest.main() + if os.path.isdir(outdir): + shutil.rmtree(outdir) +