Skip to content
Snippets Groups Projects
Commit f93e27b1 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Allow use of generator in grid based searches

This updates get_input_data_array. First, the list of possible search
parameters is defined as a list of strings in the class itself avoiding
repeated definitions of get_input_data_array. Second, is `args.clean == True`
then the input_data is not generated as an array, but a generator. For
large arrays this avoids memory issues.
parent 00b0e770
No related branches found
No related tags found
No related merge requests found
...@@ -28,6 +28,8 @@ class GridSearch(BaseSearchClass): ...@@ -28,6 +28,8 @@ class GridSearch(BaseSearchClass):
'Alpha': r'$\alpha$', 'Delta': r'$\delta$'} 'Alpha': r'$\alpha$', 'Delta': r'$\delta$'}
tex_labels0 = {'F0': '$-f_0$', 'F1': '$-\dot{f}_0$', 'F2': '$-\ddot{f}_0$', tex_labels0 = {'F0': '$-f_0$', 'F1': '$-\dot{f}_0$', 'F2': '$-\ddot{f}_0$',
'Alpha': r'$-\alpha_0$', 'Delta': r'$-\delta_0$'} 'Alpha': r'$-\alpha_0$', 'Delta': r'$-\delta_0$'}
search_labels = ['minStartTime', 'maxStartTime', 'F0s', 'F1s', 'F2s',
'Alphas', 'Deltas']
@helper_functions.initializer @helper_functions.initializer
def __init__(self, label, outdir, sftfilepattern, F0s, F1s, F2s, Alphas, def __init__(self, label, outdir, sftfilepattern, F0s, F1s, F2s, Alphas,
...@@ -68,6 +70,10 @@ class GridSearch(BaseSearchClass): ...@@ -68,6 +70,10 @@ class GridSearch(BaseSearchClass):
(one file for each doppler grid point!) (one file for each doppler grid point!)
For all other parameters, see `pyfstat.ComputeFStat` for details For all other parameters, see `pyfstat.ComputeFStat` for details
Note: if a large number of grid points are used, checks against cached
data may be slow as the array is loaded into memory. To avoid this, run
with the `clean` option which uses a generator instead.
""" """
if os.path.isdir(outdir) is False: if os.path.isdir(outdir) is False:
...@@ -117,15 +123,17 @@ class GridSearch(BaseSearchClass): ...@@ -117,15 +123,17 @@ class GridSearch(BaseSearchClass):
def get_input_data_array(self): def get_input_data_array(self):
logging.info("Generating input data array") logging.info("Generating input data array")
coord_arrays = [] coord_arrays = []
for tup in ([self.minStartTime], [self.maxStartTime], self.F0s, for sl in self.search_labels:
self.F1s, self.F2s, self.Alphas, self.Deltas): coord_arrays.append(
coord_arrays.append(self.get_array_from_tuple(tup)) self.get_array_from_tuple(np.atleast_1d(getattr(self, sl))))
self.coord_arrays = coord_arrays
self.total_iterations = np.prod([len(ca) for ca in coord_arrays])
if args.clean is False:
input_data = [] input_data = []
for vals in itertools.product(*coord_arrays): for vals in itertools.product(*coord_arrays):
input_data.append(vals) input_data.append(vals)
self.input_data = np.array(input_data) self.input_data = np.array(input_data)
self.coord_arrays = coord_arrays
def check_old_data_is_okay_to_use(self): def check_old_data_is_okay_to_use(self):
if args.clean: if args.clean:
...@@ -159,7 +167,13 @@ class GridSearch(BaseSearchClass): ...@@ -159,7 +167,13 @@ class GridSearch(BaseSearchClass):
def run(self, return_data=False): def run(self, return_data=False):
self.get_input_data_array() self.get_input_data_array()
if args.clean:
iterable = itertools.product(*self.coord_arrays)
else:
old_data = self.check_old_data_is_okay_to_use() old_data = self.check_old_data_is_okay_to_use()
iterable = self.input_data
if old_data is not False: if old_data is not False:
self.data = old_data self.data = old_data
return return
...@@ -168,7 +182,8 @@ class GridSearch(BaseSearchClass): ...@@ -168,7 +182,8 @@ class GridSearch(BaseSearchClass):
self.inititate_search_object() self.inititate_search_object()
data = [] data = []
for vals in tqdm(self.input_data): for vals in tqdm(iterable,
total=getattr(self, 'total_iterations', None)):
detstat = self.search.get_det_stat(*vals) detstat = self.search.get_det_stat(*vals)
windowRange = getattr(self.search, 'windowRange', None) windowRange = getattr(self.search, 'windowRange', None)
FstatMap = getattr(self.search, 'FstatMap', None) FstatMap = getattr(self.search, 'FstatMap', None)
...@@ -531,6 +546,9 @@ class GridUniformPriorSearch(): ...@@ -531,6 +546,9 @@ class GridUniformPriorSearch():
class GridGlitchSearch(GridSearch): class GridGlitchSearch(GridSearch):
""" Grid search using the SemiCoherentGlitchSearch """ """ Grid search using the SemiCoherentGlitchSearch """
search_labels = ['F0s', 'F1s', 'F2s', 'Alphas', 'Deltas', 'delta_F0s',
'delta_F1s', 'tglitchs']
@helper_functions.initializer @helper_functions.initializer
def __init__(self, label, outdir='data', sftfilepattern=None, F0s=[0], def __init__(self, label, outdir='data', sftfilepattern=None, F0s=[0],
F1s=[0], F2s=[0], delta_F0s=[0], delta_F1s=[0], tglitchs=None, F1s=[0], F2s=[0], delta_F0s=[0], delta_F1s=[0], tglitchs=None,
...@@ -575,19 +593,6 @@ class GridGlitchSearch(GridSearch): ...@@ -575,19 +593,6 @@ class GridGlitchSearch(GridSearch):
self.keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta', 'delta_F0', self.keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta', 'delta_F0',
'delta_F1', 'tglitch'] 'delta_F1', 'tglitch']
def get_input_data_array(self):
logging.info("Generating input data array")
coord_arrays = []
for tup in (self.F0s, self.F1s, self.F2s, self.Alphas, self.Deltas,
self.delta_F0s, self.delta_F1s, self.tglitchs):
coord_arrays.append(self.get_array_from_tuple(tup))
input_data = []
for vals in itertools.product(*coord_arrays):
input_data.append(vals)
self.input_data = np.array(input_data)
self.coord_arrays = coord_arrays
class FrequencySlidingWindow(GridSearch): class FrequencySlidingWindow(GridSearch):
""" A sliding-window search over the Frequency """ """ A sliding-window search over the Frequency """
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment