Skip to content
Snippets Groups Projects
Commit 80200e9f authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Fixes broken tests

parent 52dc771d
Branches
Tags
No related merge requests found
...@@ -828,7 +828,8 @@ class MCMCSearch(BaseSearchClass): ...@@ -828,7 +828,8 @@ class MCMCSearch(BaseSearchClass):
def get_save_data_dictionary(self): def get_save_data_dictionary(self):
d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers, d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers,
ntemps=self.ntemps, theta_keys=self.theta_keys, ntemps=self.ntemps, theta_keys=self.theta_keys,
theta_prior=self.theta_prior, scatter_val=self.scatter_val) theta_prior=self.theta_prior, scatter_val=self.scatter_val,
log10temperature_min=self.log10temperature_min)
return d return d
def save_data(self, sampler, samples, lnprobs, lnlikes): def save_data(self, sampler, samples, lnprobs, lnlikes):
......
import unittest import unittest
import pyfstat
import numpy as np import numpy as np
import os import os
import shutil
import pyfstat
class TestWriter(unittest.TestCase): class Test(unittest.TestCase):
@classmethod
def setUpClass(cls):
pass
def test_make_cff(self): @classmethod
def tearDownClass(cls):
pass
class TestWriter(Test):
label = "Test" label = "Test"
Writer = pyfstat.Writer(label, outdir='TestData')
def test_make_cff(self):
Writer = pyfstat.Writer(self.label, outdir=outdir)
Writer.make_cff() Writer.make_cff()
self.assertTrue(os.path.isfile('./TestData/Test.cff')) self.assertTrue(os.path.isfile('./TestData/Test.cff'))
def test_run_makefakedata(self): def test_run_makefakedata(self):
label = "Test" Writer = pyfstat.Writer(self.label, outdir=outdir)
Writer = pyfstat.Writer(label, outdir='TestData')
Writer.make_cff() Writer.make_cff()
Writer.run_makefakedata() Writer.run_makefakedata()
self.assertTrue(os.path.isfile( self.assertTrue(os.path.isfile(
'./TestData/H-4800_H1_1800SFT_Test-700000000-8640000.sft')) './TestData/H-4800_H1_1800SFT_Test-700000000-8640000.sft'))
def test_makefakedata_usecached(self): def test_makefakedata_usecached(self):
label = "Test" Writer = pyfstat.Writer(self.label, outdir=outdir)
Writer = pyfstat.Writer(label, outdir='TestData')
if os.path.isfile(Writer.sft_filepath): if os.path.isfile(Writer.sft_filepath):
os.remove(Writer.sft_filepath) os.remove(Writer.sft_filepath)
Writer.run_makefakedata() Writer.run_makefakedata()
...@@ -36,7 +45,7 @@ class TestWriter(unittest.TestCase): ...@@ -36,7 +45,7 @@ class TestWriter(unittest.TestCase):
self.assertFalse(time_first == time_third) self.assertFalse(time_first == time_third)
class TestBaseSearchClass(unittest.TestCase): class TestBaseSearchClass(Test):
def test_shift_matrix(self): def test_shift_matrix(self):
BSC = pyfstat.BaseSearchClass() BSC = pyfstat.BaseSearchClass()
dT = 10 dT = 10
...@@ -75,12 +84,11 @@ class TestBaseSearchClass(unittest.TestCase): ...@@ -75,12 +84,11 @@ class TestBaseSearchClass(unittest.TestCase):
rtol=1e-9, atol=1e-9)) rtol=1e-9, atol=1e-9))
class TestComputeFstat(unittest.TestCase): class TestComputeFstat(Test):
label = "Test" label = "Test"
outdir = 'TestData'
def test_run_computefstatistic_single_point(self): def test_run_computefstatistic_single_point(self):
Writer = pyfstat.Writer(self.label, outdir=self.outdir) Writer = pyfstat.Writer(self.label, outdir=outdir)
Writer.make_data() Writer.make_data()
predicted_FS = Writer.predict_fstat() predicted_FS = Writer.predict_fstat()
...@@ -97,15 +105,14 @@ class TestComputeFstat(unittest.TestCase): ...@@ -97,15 +105,14 @@ class TestComputeFstat(unittest.TestCase):
self.assertTrue(np.abs(predicted_FS-FS)/FS < 0.1) self.assertTrue(np.abs(predicted_FS-FS)/FS < 0.1)
class TestSemiCoherentGlitchSearch(unittest.TestCase): class TestSemiCoherentGlitchSearch(Test):
label = "Test" label = "Test"
outdir = 'TestData'
def test_compute_nglitch_fstat(self): def test_compute_nglitch_fstat(self):
duration = 100*86400 duration = 100*86400
dtglitch = 100*43200 dtglitch = 100*43200
delta_F0 = 0 delta_F0 = 0
Writer = pyfstat.Writer(self.label, outdir=self.outdir, Writer = pyfstat.Writer(self.label, outdir=outdir,
duration=duration, dtglitch=dtglitch, duration=duration, dtglitch=dtglitch,
delta_F0=delta_F0) delta_F0=delta_F0)
...@@ -137,9 +144,8 @@ class TestSemiCoherentGlitchSearch(unittest.TestCase): ...@@ -137,9 +144,8 @@ class TestSemiCoherentGlitchSearch(unittest.TestCase):
self.assertTrue(np.abs((FS - predicted_FS))/predicted_FS < 0.3) self.assertTrue(np.abs((FS - predicted_FS))/predicted_FS < 0.3)
class TestMCMCSearch(unittest.TestCase): class TestMCMCSearch(Test):
label = "MCMCTest" label = "Test"
outdir = 'TestData'
def test_fully_coherent(self): def test_fully_coherent(self):
h0 = 1e-24 h0 = 1e-24
...@@ -157,7 +163,7 @@ class TestMCMCSearch(unittest.TestCase): ...@@ -157,7 +163,7 @@ class TestMCMCSearch(unittest.TestCase):
delta_F0 = 0 delta_F0 = 0
Writer = pyfstat.Writer(F0=F0, F1=F1, F2=F2, label=self.label, Writer = pyfstat.Writer(F0=F0, F1=F1, F2=F2, label=self.label,
h0=h0, sqrtSX=sqrtSX, h0=h0, sqrtSX=sqrtSX,
outdir=self.outdir, tstart=tstart, outdir=outdir, tstart=tstart,
Alpha=Alpha, Delta=Delta, tref=tref, Alpha=Alpha, Delta=Delta, tref=tref,
duration=duration, dtglitch=dtglitch, duration=duration, dtglitch=dtglitch,
delta_F0=delta_F0, Band=4) delta_F0=delta_F0, Band=4)
...@@ -170,8 +176,8 @@ class TestMCMCSearch(unittest.TestCase): ...@@ -170,8 +176,8 @@ class TestMCMCSearch(unittest.TestCase):
'F2': F2, 'Alpha': Alpha, 'Delta': Delta} 'F2': F2, 'Alpha': Alpha, 'Delta': Delta}
search = pyfstat.MCMCSearch( search = pyfstat.MCMCSearch(
label=self.label, outdir=self.outdir, theta_prior=theta, tref=tref, label=self.label, outdir=outdir, theta_prior=theta, tref=tref,
sftlabel=self.label, sftdir=self.outdir, sftlabel=self.label, sftdir=outdir,
tstart=tstart, tend=tend, nsteps=[100, 100], nwalkers=100, tstart=tstart, tend=tend, nsteps=[100, 100], nwalkers=100,
ntemps=1) ntemps=1)
search.run() search.run()
...@@ -185,4 +191,10 @@ class TestMCMCSearch(unittest.TestCase): ...@@ -185,4 +191,10 @@ class TestMCMCSearch(unittest.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
outdir = 'TestData'
if os.path.isdir(outdir):
shutil.rmtree(outdir)
unittest.main() unittest.main()
if os.path.isdir(outdir):
shutil.rmtree(outdir)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment