Skip to content
Snippets Groups Projects
Commit b85bdda0 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Improvements to the tests and fixes

parent cd13d3f6
Branches
Tags
No related merge requests found
...@@ -254,7 +254,7 @@ class GridSearch(BaseSearchClass): ...@@ -254,7 +254,7 @@ class GridSearch(BaseSearchClass):
fig.tight_layout() fig.tight_layout()
fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label)) fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label))
else: else:
return fig, ax return ax
def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None, def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None,
add_mismatch=None, xN=None, yN=None, flat_keys=[], add_mismatch=None, xN=None, yN=None, flat_keys=[],
...@@ -409,10 +409,12 @@ class SliceGridSearch(GridSearch): ...@@ -409,10 +409,12 @@ class SliceGridSearch(GridSearch):
self.ndim = 4 self.ndim = 4
self.search_keys = ['F0', 'F1', 'Alpha', 'Delta'] self.search_keys = ['F0', 'F1', 'Alpha', 'Delta']
self.Lambda0 = np.array(Lambda0) if self.Lambda0 is None:
raise ValueError('Lambda0 undefined')
if len(self.Lambda0) != len(self.search_keys): if len(self.Lambda0) != len(self.search_keys):
raise ValueError( raise ValueError(
'Lambda0 must be of length {}'.format(len(self.search_keys))) 'Lambda0 must be of length {}'.format(len(self.search_keys)))
self.Lambda0 = np.array(Lambda0)
def run(self, factor=2, max_n_ticks=4, whspace=0.07, save=True, def run(self, factor=2, max_n_ticks=4, whspace=0.07, save=True,
**kwargs): **kwargs):
...@@ -613,6 +615,10 @@ class FrequencySlidingWindow(GridSearch): ...@@ -613,6 +615,10 @@ class FrequencySlidingWindow(GridSearch):
For all other parameters, see `pyfstat.ComputeFStat` for details For all other parameters, see `pyfstat.ComputeFStat` for details
""" """
self.transientWindowType = None
self.t0Band = None
self.tauBand = None
if os.path.isdir(outdir) is False: if os.path.isdir(outdir) is False:
os.mkdir(outdir) os.mkdir(outdir)
self.set_out_file() self.set_out_file()
...@@ -622,6 +628,7 @@ class FrequencySlidingWindow(GridSearch): ...@@ -622,6 +628,7 @@ class FrequencySlidingWindow(GridSearch):
self.Alphas = [Alpha] self.Alphas = [Alpha]
self.Deltas = [Delta] self.Deltas = [Delta]
self.input_arrays = False self.input_arrays = False
self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta']
def inititate_search_object(self): def inititate_search_object(self):
logging.info('Setting up search object') logging.info('Setting up search object')
......
...@@ -13,6 +13,27 @@ class Test(unittest.TestCase): ...@@ -13,6 +13,27 @@ class Test(unittest.TestCase):
def setUpClass(self): def setUpClass(self):
if os.path.isdir(self.outdir): if os.path.isdir(self.outdir):
shutil.rmtree(self.outdir) shutil.rmtree(self.outdir)
h0 = 1
sqrtSX = 1
F0 = 30
F1 = -1e-10
F2 = 0
minStartTime = 700000000
duration = 2 * 86400
Alpha = 5e-3
Delta = 1.2
tref = minStartTime
Writer = pyfstat.Writer(F0=F0, F1=F1, F2=F2, label='test',
h0=h0, sqrtSX=sqrtSX,
outdir=self.outdir, tstart=minStartTime,
Alpha=Alpha, Delta=Delta, tref=tref,
duration=duration,
Band=4)
Writer.make_data()
self.sftfilepath = Writer.sftfilepath
self.minStartTime = minStartTime
self.maxStartTime = minStartTime + duration
self.duration = duration
@classmethod @classmethod
def tearDownClass(self): def tearDownClass(self):
...@@ -20,7 +41,7 @@ class Test(unittest.TestCase): ...@@ -20,7 +41,7 @@ class Test(unittest.TestCase):
shutil.rmtree(self.outdir) shutil.rmtree(self.outdir)
class TestWriter(Test): class Writer(Test):
label = "TestWriter" label = "TestWriter"
def test_make_cff(self): def test_make_cff(self):
...@@ -53,13 +74,13 @@ class TestWriter(Test): ...@@ -53,13 +74,13 @@ class TestWriter(Test):
self.assertFalse(time_first == time_third) self.assertFalse(time_first == time_third)
class TestBunch(Test): class Bunch(Test):
def test_bunch(self): def test_bunch(self):
b = pyfstat.core.Bunch(dict(x=10)) b = pyfstat.core.Bunch(dict(x=10))
self.assertTrue(b.x == 10) self.assertTrue(b.x == 10)
class Test_par(Test): class par(Test):
label = 'TestPar' label = 'TestPar'
def test(self): def test(self):
...@@ -79,7 +100,7 @@ class Test_par(Test): ...@@ -79,7 +100,7 @@ class Test_par(Test):
os.system('rm -r {}'.format(self.outdir)) os.system('rm -r {}'.format(self.outdir))
class TestBaseSearchClass(Test): class BaseSearchClass(Test):
def test_shift_matrix(self): def test_shift_matrix(self):
BSC = pyfstat.BaseSearchClass() BSC = pyfstat.BaseSearchClass()
dT = 10 dT = 10
...@@ -118,7 +139,7 @@ class TestBaseSearchClass(Test): ...@@ -118,7 +139,7 @@ class TestBaseSearchClass(Test):
rtol=1e-9, atol=1e-9)) rtol=1e-9, atol=1e-9))
class TestComputeFstat(Test): class ComputeFstat(Test):
label = "TestComputeFstat" label = "TestComputeFstat"
def test_run_computefstatistic_single_point(self): def test_run_computefstatistic_single_point(self):
...@@ -196,7 +217,7 @@ class TestComputeFstat(Test): ...@@ -196,7 +217,7 @@ class TestComputeFstat(Test):
self.assertTrue(FS_from_dict == FS_from_file) self.assertTrue(FS_from_dict == FS_from_file)
class TestSemiCoherentSearch(Test): class SemiCoherentSearch(Test):
label = "TestSemiCoherentSearch" label = "TestSemiCoherentSearch"
def test_get_semicoherent_twoF(self): def test_get_semicoherent_twoF(self):
...@@ -249,7 +270,7 @@ class TestSemiCoherentSearch(Test): ...@@ -249,7 +270,7 @@ class TestSemiCoherentSearch(Test):
self.assertTrue(BSGL > 0) self.assertTrue(BSGL > 0)
class TestSemiCoherentGlitchSearch(Test): class SemiCoherentGlitchSearch(Test):
label = "TestSemiCoherentGlitchSearch" label = "TestSemiCoherentGlitchSearch"
def test_get_semicoherent_nglitch_twoF(self): def test_get_semicoherent_nglitch_twoF(self):
...@@ -292,7 +313,7 @@ class TestSemiCoherentGlitchSearch(Test): ...@@ -292,7 +313,7 @@ class TestSemiCoherentGlitchSearch(Test):
self.assertTrue(np.abs((FS - predicted_FS))/predicted_FS < 0.3) self.assertTrue(np.abs((FS - predicted_FS))/predicted_FS < 0.3)
class TestMCMCSearch(Test): class MCMCSearch(Test):
label = "TestMCMCSearch" label = "TestMCMCSearch"
def test_fully_coherent(self): def test_fully_coherent(self):
...@@ -335,5 +356,49 @@ class TestMCMCSearch(Test): ...@@ -335,5 +356,49 @@ class TestMCMCSearch(Test):
FS > predicted_FS or np.abs((FS-predicted_FS))/predicted_FS < 0.3) FS > predicted_FS or np.abs((FS-predicted_FS))/predicted_FS < 0.3)
class GridSearch(Test):
F0s = [29, 31, 0.1]
F1s = [-1e-10, 0, 1e-11]
tref = 700000000
def test_grid_search(self):
search = pyfstat.GridSearch(
'grid_search', self.outdir, self.sftfilepath, F0s=self.F0s,
F1s=[0], F2s=[0], Alphas=[0], Deltas=[0], tref=self.tref)
search.run()
self.assertTrue(os.path.isfile(search.out_file))
def test_semicoherent_grid_search(self):
search = pyfstat.GridSearch(
'sc_grid_search', self.outdir, self.sftfilepath, F0s=self.F0s,
F1s=[0], F2s=[0], Alphas=[0], Deltas=[0], tref=self.tref, nsegs=2)
search.run()
self.assertTrue(os.path.isfile(search.out_file))
def test_slice_grid_search(self):
search = pyfstat.SliceGridSearch(
'slice_grid_search', self.outdir, self.sftfilepath, F0s=self.F0s,
F1s=self.F1s, F2s=[0], Alphas=[0], Deltas=[0], tref=self.tref,
Lambda0=[30, 0, 0, 0])
search.run()
self.assertTrue(os.path.isfile('{}/{}_slice_projection.png'
.format(search.outdir, search.label)))
def test_glitch_grid_search(self):
search = pyfstat.GridGlitchSearch(
'grid_grid_search', self.outdir, self.sftfilepath, F0s=self.F0s,
F1s=self.F1s, F2s=[0], Alphas=[0], Deltas=[0], tref=self.tref,
tglitchs=[self.tref])
search.run()
self.assertTrue(os.path.isfile(search.out_file))
def test_sliding_window(self):
search = pyfstat.FrequencySlidingWindow(
'grid_grid_search', self.outdir, self.sftfilepath, F0s=self.F0s,
F1=0, F2=0, Alpha=0, Delta=0, tref=self.tref,
minStartTime=self.minStartTime, maxStartTime=self.maxStartTime)
search.run()
self.assertTrue(os.path.isfile(search.out_file))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment