Skip to content
Snippets Groups Projects
Commit 6580d08a authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Update to generator rather than saving input_data in memory

parent 608b4b94
No related branches found
No related tags found
No related merge requests found
......@@ -87,17 +87,14 @@ class GridSearch(BaseSearchClass):
def get_input_data_array(self):
logging.info("Generating input data array")
arrays = []
coord_arrays = []
for tup in ([self.minStartTime], [self.maxStartTime], self.F0s, self.F1s, self.F2s,
self.Alphas, self.Deltas):
arrays.append(self.get_array_from_tuple(tup))
input_data = []
for vals in itertools.product(*arrays):
input_data.append(vals)
coord_arrays.append(self.get_array_from_tuple(tup))
self.arrays = arrays
self.input_data = np.array(input_data)
self.input_data_generator_len = np.prod([len(k) for k in coord_arrays])
self.input_data_generator = itertools.product(*coord_arrays)
self.coord_arrays = coord_arrays
def check_old_data_is_okay_to_use(self):
if args.clean:
......@@ -112,15 +109,18 @@ class GridSearch(BaseSearchClass):
logging.info('Search output data outdates sft files,'
+ ' continuing with grid search')
return False
data = np.atleast_2d(np.genfromtxt(self.out_file, delimiter=' '))
if np.all(data[:, 0:-1] == self.input_data):
logging.info(
'Old data found with matching input, no search performed')
return data
else:
logging.info(
'Old data found, input differs, continuing with grid search')
logging.info('No data caching available')
return False
#data = np.atleast_2d(np.genfromtxt(self.out_file, delimiter=' '))
#if np.all(data[:, 0:-1] == self.input_data):
# logging.info(
# 'Old data found with matching input, no search performed')
# return data
#else:
# logging.info(
# 'Old data found, input differs, continuing with grid search')
# return False
def run(self, return_data=False):
self.get_input_data_array()
......@@ -131,11 +131,8 @@ class GridSearch(BaseSearchClass):
self.inititate_search_object()
logging.info('Total number of grid points is {}'.format(
len(self.input_data)))
data = []
for vals in tqdm(self.input_data):
for vals in tqdm(self.input_data_generator):
FS = self.search.get_det_stat(*vals)
data.append(list(vals) + [FS])
......@@ -330,9 +327,9 @@ class SliceGridSearch(GridSearch):
self.get_input_data_array()
self.Lambda0s_grid = []
for j in range(self.input_data.shape[1]):
i = np.argmin(np.abs(self.Lambda0[j]-self.input_data[:, j]))
self.Lambda0s_grid.append(self.input_data[:, j][i])
for j, arr in enumerate(self.coord_arrays):
i = np.argmin(np.abs(self.Lambda0[j]-arr))
self.Lambda0s_grid.append(arr[i])
old_data = self.check_old_data_is_okay_to_use()
if old_data is not False:
......@@ -342,11 +339,12 @@ class SliceGridSearch(GridSearch):
self.inititate_search_object()
logging.info('Total number of grid points is {}'.format(
len(self.input_data)))
self.input_data_generator_len))
data = []
for vals in tqdm(self.input_data):
if np.sum(vals != self.Lambda0s_grid) < 3:
for vals in tqdm(self.input_data_generator,
total=self.input_data_generator_len):
if np.sum(np.array(vals) != np.array(self.Lambda0s_grid)) < 3:
FS = self.search.get_det_stat(*vals)
data.append(list(vals) + [FS])
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment