diff --git a/pyfstat/grid_based_searches.py b/pyfstat/grid_based_searches.py index 46e2cf8dc5d4bddb81398394fa8c60859f9f478f..de51d0fa03936a72ca2d3fdbe1a852671b49ef0f 100644 --- a/pyfstat/grid_based_searches.py +++ b/pyfstat/grid_based_searches.py @@ -28,6 +28,8 @@ class GridSearch(BaseSearchClass): 'Alpha': r'$\alpha$', 'Delta': r'$\delta$'} tex_labels0 = {'F0': '$-f_0$', 'F1': '$-\dot{f}_0$', 'F2': '$-\ddot{f}_0$', 'Alpha': r'$-\alpha_0$', 'Delta': r'$-\delta_0$'} + search_labels = ['minStartTime', 'maxStartTime', 'F0s', 'F1s', 'F2s', + 'Alphas', 'Deltas'] @helper_functions.initializer def __init__(self, label, outdir, sftfilepattern, F0s, F1s, F2s, Alphas, @@ -68,6 +70,10 @@ class GridSearch(BaseSearchClass): (one file for each doppler grid point!) For all other parameters, see `pyfstat.ComputeFStat` for details + + Note: if a large number of grid points are used, checks against cached + data may be slow as the array is loaded into memory. To avoid this, run + with the `clean` option which uses a generator instead. """ if os.path.isdir(outdir) is False: @@ -117,15 +123,17 @@ class GridSearch(BaseSearchClass): def get_input_data_array(self): logging.info("Generating input data array") coord_arrays = [] - for tup in ([self.minStartTime], [self.maxStartTime], self.F0s, - self.F1s, self.F2s, self.Alphas, self.Deltas): - coord_arrays.append(self.get_array_from_tuple(tup)) - - input_data = [] - for vals in itertools.product(*coord_arrays): - input_data.append(vals) - self.input_data = np.array(input_data) + for sl in self.search_labels: + coord_arrays.append( + self.get_array_from_tuple(np.atleast_1d(getattr(self, sl)))) self.coord_arrays = coord_arrays + self.total_iterations = np.prod([len(ca) for ca in coord_arrays]) + + if args.clean is False: + input_data = [] + for vals in itertools.product(*coord_arrays): + input_data.append(vals) + self.input_data = np.array(input_data) def check_old_data_is_okay_to_use(self): if args.clean: @@ -159,16 +167,23 @@ class GridSearch(BaseSearchClass): def run(self, return_data=False): self.get_input_data_array() - old_data = self.check_old_data_is_okay_to_use() - if old_data is not False: - self.data = old_data - return + + if args.clean: + iterable = itertools.product(*self.coord_arrays) + else: + old_data = self.check_old_data_is_okay_to_use() + iterable = self.input_data + + if old_data is not False: + self.data = old_data + return if hasattr(self, 'search') is False: self.inititate_search_object() data = [] - for vals in tqdm(self.input_data): + for vals in tqdm(iterable, + total=getattr(self, 'total_iterations', None)): detstat = self.search.get_det_stat(*vals) windowRange = getattr(self.search, 'windowRange', None) FstatMap = getattr(self.search, 'FstatMap', None) @@ -531,6 +546,9 @@ class GridUniformPriorSearch(): class GridGlitchSearch(GridSearch): """ Grid search using the SemiCoherentGlitchSearch """ + search_labels = ['F0s', 'F1s', 'F2s', 'Alphas', 'Deltas', 'delta_F0s', + 'delta_F1s', 'tglitchs'] + @helper_functions.initializer def __init__(self, label, outdir='data', sftfilepattern=None, F0s=[0], F1s=[0], F2s=[0], delta_F0s=[0], delta_F1s=[0], tglitchs=None, @@ -575,19 +593,6 @@ class GridGlitchSearch(GridSearch): self.keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta', 'delta_F0', 'delta_F1', 'tglitch'] - def get_input_data_array(self): - logging.info("Generating input data array") - coord_arrays = [] - for tup in (self.F0s, self.F1s, self.F2s, self.Alphas, self.Deltas, - self.delta_F0s, self.delta_F1s, self.tglitchs): - coord_arrays.append(self.get_array_from_tuple(tup)) - - input_data = [] - for vals in itertools.product(*coord_arrays): - input_data.append(vals) - self.input_data = np.array(input_data) - self.coord_arrays = coord_arrays - class FrequencySlidingWindow(GridSearch): """ A sliding-window search over the Frequency """