diff --git a/pyfstat/grid_based_searches.py b/pyfstat/grid_based_searches.py index f9e269921b51dc02262c7799ae86c388dcae786c..46e2cf8dc5d4bddb81398394fa8c60859f9f478f 100644 --- a/pyfstat/grid_based_searches.py +++ b/pyfstat/grid_based_searches.py @@ -254,7 +254,7 @@ class GridSearch(BaseSearchClass): fig.tight_layout() fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label)) else: - return fig, ax + return ax def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None, add_mismatch=None, xN=None, yN=None, flat_keys=[], @@ -409,10 +409,12 @@ class SliceGridSearch(GridSearch): self.ndim = 4 self.search_keys = ['F0', 'F1', 'Alpha', 'Delta'] - self.Lambda0 = np.array(Lambda0) + if self.Lambda0 is None: + raise ValueError('Lambda0 undefined') if len(self.Lambda0) != len(self.search_keys): raise ValueError( 'Lambda0 must be of length {}'.format(len(self.search_keys))) + self.Lambda0 = np.array(Lambda0) def run(self, factor=2, max_n_ticks=4, whspace=0.07, save=True, **kwargs): @@ -613,6 +615,10 @@ class FrequencySlidingWindow(GridSearch): For all other parameters, see `pyfstat.ComputeFStat` for details """ + self.transientWindowType = None + self.t0Band = None + self.tauBand = None + if os.path.isdir(outdir) is False: os.mkdir(outdir) self.set_out_file() @@ -622,6 +628,7 @@ class FrequencySlidingWindow(GridSearch): self.Alphas = [Alpha] self.Deltas = [Delta] self.input_arrays = False + self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta'] def inititate_search_object(self): logging.info('Setting up search object') diff --git a/tests.py b/tests.py index 50b4c2fed8b5044cb60063349157e7300400499d..4b853e4fd0d61247fe68a75c481ffc935a542ba8 100644 --- a/tests.py +++ b/tests.py @@ -13,6 +13,27 @@ class Test(unittest.TestCase): def setUpClass(self): if os.path.isdir(self.outdir): shutil.rmtree(self.outdir) + h0 = 1 + sqrtSX = 1 + F0 = 30 + F1 = -1e-10 + F2 = 0 + minStartTime = 700000000 + duration = 2 * 86400 + Alpha = 5e-3 + Delta = 1.2 + tref = minStartTime + Writer = pyfstat.Writer(F0=F0, F1=F1, F2=F2, label='test', + h0=h0, sqrtSX=sqrtSX, + outdir=self.outdir, tstart=minStartTime, + Alpha=Alpha, Delta=Delta, tref=tref, + duration=duration, + Band=4) + Writer.make_data() + self.sftfilepath = Writer.sftfilepath + self.minStartTime = minStartTime + self.maxStartTime = minStartTime + duration + self.duration = duration @classmethod def tearDownClass(self): @@ -20,7 +41,7 @@ class Test(unittest.TestCase): shutil.rmtree(self.outdir) -class TestWriter(Test): +class Writer(Test): label = "TestWriter" def test_make_cff(self): @@ -53,13 +74,13 @@ class TestWriter(Test): self.assertFalse(time_first == time_third) -class TestBunch(Test): +class Bunch(Test): def test_bunch(self): b = pyfstat.core.Bunch(dict(x=10)) self.assertTrue(b.x == 10) -class Test_par(Test): +class par(Test): label = 'TestPar' def test(self): @@ -79,7 +100,7 @@ class Test_par(Test): os.system('rm -r {}'.format(self.outdir)) -class TestBaseSearchClass(Test): +class BaseSearchClass(Test): def test_shift_matrix(self): BSC = pyfstat.BaseSearchClass() dT = 10 @@ -118,7 +139,7 @@ class TestBaseSearchClass(Test): rtol=1e-9, atol=1e-9)) -class TestComputeFstat(Test): +class ComputeFstat(Test): label = "TestComputeFstat" def test_run_computefstatistic_single_point(self): @@ -196,7 +217,7 @@ class TestComputeFstat(Test): self.assertTrue(FS_from_dict == FS_from_file) -class TestSemiCoherentSearch(Test): +class SemiCoherentSearch(Test): label = "TestSemiCoherentSearch" def test_get_semicoherent_twoF(self): @@ -249,7 +270,7 @@ class TestSemiCoherentSearch(Test): self.assertTrue(BSGL > 0) -class TestSemiCoherentGlitchSearch(Test): +class SemiCoherentGlitchSearch(Test): label = "TestSemiCoherentGlitchSearch" def test_get_semicoherent_nglitch_twoF(self): @@ -292,7 +313,7 @@ class TestSemiCoherentGlitchSearch(Test): self.assertTrue(np.abs((FS - predicted_FS))/predicted_FS < 0.3) -class TestMCMCSearch(Test): +class MCMCSearch(Test): label = "TestMCMCSearch" def test_fully_coherent(self): @@ -335,5 +356,49 @@ class TestMCMCSearch(Test): FS > predicted_FS or np.abs((FS-predicted_FS))/predicted_FS < 0.3) +class GridSearch(Test): + F0s = [29, 31, 0.1] + F1s = [-1e-10, 0, 1e-11] + tref = 700000000 + + def test_grid_search(self): + search = pyfstat.GridSearch( + 'grid_search', self.outdir, self.sftfilepath, F0s=self.F0s, + F1s=[0], F2s=[0], Alphas=[0], Deltas=[0], tref=self.tref) + search.run() + self.assertTrue(os.path.isfile(search.out_file)) + + def test_semicoherent_grid_search(self): + search = pyfstat.GridSearch( + 'sc_grid_search', self.outdir, self.sftfilepath, F0s=self.F0s, + F1s=[0], F2s=[0], Alphas=[0], Deltas=[0], tref=self.tref, nsegs=2) + search.run() + self.assertTrue(os.path.isfile(search.out_file)) + + def test_slice_grid_search(self): + search = pyfstat.SliceGridSearch( + 'slice_grid_search', self.outdir, self.sftfilepath, F0s=self.F0s, + F1s=self.F1s, F2s=[0], Alphas=[0], Deltas=[0], tref=self.tref, + Lambda0=[30, 0, 0, 0]) + search.run() + self.assertTrue(os.path.isfile('{}/{}_slice_projection.png' + .format(search.outdir, search.label))) + + def test_glitch_grid_search(self): + search = pyfstat.GridGlitchSearch( + 'grid_grid_search', self.outdir, self.sftfilepath, F0s=self.F0s, + F1s=self.F1s, F2s=[0], Alphas=[0], Deltas=[0], tref=self.tref, + tglitchs=[self.tref]) + search.run() + self.assertTrue(os.path.isfile(search.out_file)) + + def test_sliding_window(self): + search = pyfstat.FrequencySlidingWindow( + 'grid_grid_search', self.outdir, self.sftfilepath, F0s=self.F0s, + F1=0, F2=0, Alpha=0, Delta=0, tref=self.tref, + minStartTime=self.minStartTime, maxStartTime=self.maxStartTime) + search.run() + self.assertTrue(os.path.isfile(search.out_file)) + if __name__ == '__main__': unittest.main()