Commit 6580d08a authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Update to generator rather than saving input_data in memory

parent 608b4b94
...@@ -87,17 +87,14 @@ class GridSearch(BaseSearchClass): ...@@ -87,17 +87,14 @@ class GridSearch(BaseSearchClass):
def get_input_data_array(self): def get_input_data_array(self):
logging.info("Generating input data array") logging.info("Generating input data array")
arrays = [] coord_arrays = []
for tup in ([self.minStartTime], [self.maxStartTime], self.F0s, self.F1s, self.F2s, for tup in ([self.minStartTime], [self.maxStartTime], self.F0s, self.F1s, self.F2s,
self.Alphas, self.Deltas): self.Alphas, self.Deltas):
arrays.append(self.get_array_from_tuple(tup)) coord_arrays.append(self.get_array_from_tuple(tup))
input_data = [] self.input_data_generator_len = np.prod([len(k) for k in coord_arrays])
for vals in itertools.product(*arrays): self.input_data_generator = itertools.product(*coord_arrays)
input_data.append(vals) self.coord_arrays = coord_arrays
self.arrays = arrays
self.input_data = np.array(input_data)
def check_old_data_is_okay_to_use(self): def check_old_data_is_okay_to_use(self):
if args.clean: if args.clean:
...@@ -112,15 +109,18 @@ class GridSearch(BaseSearchClass): ...@@ -112,15 +109,18 @@ class GridSearch(BaseSearchClass):
logging.info('Search output data outdates sft files,' logging.info('Search output data outdates sft files,'
+ ' continuing with grid search') + ' continuing with grid search')
return False return False
data = np.atleast_2d(np.genfromtxt(self.out_file, delimiter=' '))
if np.all(data[:, 0:-1] == self.input_data): logging.info('No data caching available')
logging.info( return False
'Old data found with matching input, no search performed') #data = np.atleast_2d(np.genfromtxt(self.out_file, delimiter=' '))
return data #if np.all(data[:, 0:-1] == self.input_data):
else: # logging.info(
logging.info( # 'Old data found with matching input, no search performed')
'Old data found, input differs, continuing with grid search') # return data
return False #else:
# logging.info(
# 'Old data found, input differs, continuing with grid search')
# return False
def run(self, return_data=False): def run(self, return_data=False):
self.get_input_data_array() self.get_input_data_array()
...@@ -131,11 +131,8 @@ class GridSearch(BaseSearchClass): ...@@ -131,11 +131,8 @@ class GridSearch(BaseSearchClass):
self.inititate_search_object() self.inititate_search_object()
logging.info('Total number of grid points is {}'.format(
len(self.input_data)))
data = [] data = []
for vals in tqdm(self.input_data): for vals in tqdm(self.input_data_generator):
FS = self.search.get_det_stat(*vals) FS = self.search.get_det_stat(*vals)
data.append(list(vals) + [FS]) data.append(list(vals) + [FS])
...@@ -330,9 +327,9 @@ class SliceGridSearch(GridSearch): ...@@ -330,9 +327,9 @@ class SliceGridSearch(GridSearch):
self.get_input_data_array() self.get_input_data_array()
self.Lambda0s_grid = [] self.Lambda0s_grid = []
for j in range(self.input_data.shape[1]): for j, arr in enumerate(self.coord_arrays):
i = np.argmin(np.abs(self.Lambda0[j]-self.input_data[:, j])) i = np.argmin(np.abs(self.Lambda0[j]-arr))
self.Lambda0s_grid.append(self.input_data[:, j][i]) self.Lambda0s_grid.append(arr[i])
old_data = self.check_old_data_is_okay_to_use() old_data = self.check_old_data_is_okay_to_use()
if old_data is not False: if old_data is not False:
...@@ -342,11 +339,12 @@ class SliceGridSearch(GridSearch): ...@@ -342,11 +339,12 @@ class SliceGridSearch(GridSearch):
self.inititate_search_object() self.inititate_search_object()
logging.info('Total number of grid points is {}'.format( logging.info('Total number of grid points is {}'.format(
len(self.input_data))) self.input_data_generator_len))
data = [] data = []
for vals in tqdm(self.input_data): for vals in tqdm(self.input_data_generator,
if np.sum(vals != self.Lambda0s_grid) < 3: total=self.input_data_generator_len):
if np.sum(np.array(vals) != np.array(self.Lambda0s_grid)) < 3:
FS = self.search.get_det_stat(*vals) FS = self.search.get_det_stat(*vals)
data.append(list(vals) + [FS]) data.append(list(vals) + [FS])
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment