diff --git a/pyfstat/grid_based_searches.py b/pyfstat/grid_based_searches.py index c82d79065706b495d7ebdc331d8e74b7990e8ed2..d231961df1465f14e231c479f78b305695a1a50b 100644 --- a/pyfstat/grid_based_searches.py +++ b/pyfstat/grid_based_searches.py @@ -87,17 +87,14 @@ class GridSearch(BaseSearchClass): def get_input_data_array(self): logging.info("Generating input data array") - arrays = [] + coord_arrays = [] for tup in ([self.minStartTime], [self.maxStartTime], self.F0s, self.F1s, self.F2s, self.Alphas, self.Deltas): - arrays.append(self.get_array_from_tuple(tup)) + coord_arrays.append(self.get_array_from_tuple(tup)) - input_data = [] - for vals in itertools.product(*arrays): - input_data.append(vals) - - self.arrays = arrays - self.input_data = np.array(input_data) + self.input_data_generator_len = np.prod([len(k) for k in coord_arrays]) + self.input_data_generator = itertools.product(*coord_arrays) + self.coord_arrays = coord_arrays def check_old_data_is_okay_to_use(self): if args.clean: @@ -112,15 +109,18 @@ class GridSearch(BaseSearchClass): logging.info('Search output data outdates sft files,' + ' continuing with grid search') return False - data = np.atleast_2d(np.genfromtxt(self.out_file, delimiter=' ')) - if np.all(data[:, 0:-1] == self.input_data): - logging.info( - 'Old data found with matching input, no search performed') - return data - else: - logging.info( - 'Old data found, input differs, continuing with grid search') - return False + + logging.info('No data caching available') + return False + #data = np.atleast_2d(np.genfromtxt(self.out_file, delimiter=' ')) + #if np.all(data[:, 0:-1] == self.input_data): + # logging.info( + # 'Old data found with matching input, no search performed') + # return data + #else: + # logging.info( + # 'Old data found, input differs, continuing with grid search') + # return False def run(self, return_data=False): self.get_input_data_array() @@ -131,11 +131,8 @@ class GridSearch(BaseSearchClass): self.inititate_search_object() - logging.info('Total number of grid points is {}'.format( - len(self.input_data))) - data = [] - for vals in tqdm(self.input_data): + for vals in tqdm(self.input_data_generator): FS = self.search.get_det_stat(*vals) data.append(list(vals) + [FS]) @@ -330,9 +327,9 @@ class SliceGridSearch(GridSearch): self.get_input_data_array() self.Lambda0s_grid = [] - for j in range(self.input_data.shape[1]): - i = np.argmin(np.abs(self.Lambda0[j]-self.input_data[:, j])) - self.Lambda0s_grid.append(self.input_data[:, j][i]) + for j, arr in enumerate(self.coord_arrays): + i = np.argmin(np.abs(self.Lambda0[j]-arr)) + self.Lambda0s_grid.append(arr[i]) old_data = self.check_old_data_is_okay_to_use() if old_data is not False: @@ -342,11 +339,12 @@ class SliceGridSearch(GridSearch): self.inititate_search_object() logging.info('Total number of grid points is {}'.format( - len(self.input_data))) + self.input_data_generator_len)) data = [] - for vals in tqdm(self.input_data): - if np.sum(vals != self.Lambda0s_grid) < 3: + for vals in tqdm(self.input_data_generator, + total=self.input_data_generator_len): + if np.sum(np.array(vals) != np.array(self.Lambda0s_grid)) < 3: FS = self.search.get_det_stat(*vals) data.append(list(vals) + [FS]) else: