Skip to content
Snippets Groups Projects
Commit 3ac12eb2 authored by David Keitel's avatar David Keitel
Browse files

tCWmap init: only do CUDA stuff if requested

parent 1c349df9
No related branches found
No related tags found
No related merge requests found
......@@ -660,7 +660,7 @@ class ComputeFstat(BaseSearchClass):
if self.dtau:
self.windowRange.dtau = self.dtau
self.tCWFstatMapFeatures, self.gpu_context = tcw.init_transient_fstat_map_features(self.cudaDeviceName)
self.tCWFstatMapFeatures, self.gpu_context = tcw.init_transient_fstat_map_features ( self.tCWFstatMapVersion=='pycuda', self.cudaDeviceName )
def get_fullycoherent_twoF(self, tstart, tend, F0, F1, F2, Alpha, Delta,
asini=None, period=None, ecc=None, tp=None,
......
......@@ -89,7 +89,7 @@ fstatmap_versions = {
}
def init_transient_fstat_map_features ( cudaDeviceName ):
def init_transient_fstat_map_features ( wantCuda=False, cudaDeviceName=None ):
'''
Initialization of available modules (or "features") for F-stat maps.
......@@ -115,7 +115,7 @@ def init_transient_fstat_map_features ( cudaDeviceName ):
logging.debug('Got the following features for transient F-stat maps:')
logging.debug(features)
if features['pycuda']:
if wantCuda and features['pycuda']:
logging.debug('CUDA version: {}'.format(drv.get_version()))
drv.init()
......@@ -138,6 +138,7 @@ def init_transient_fstat_map_features ( cudaDeviceName ):
else:
devnum0 = 0
matchbit = ''
if cudaDeviceName:
# allow partial matches in device names
devmatches = [devidx for devidx, devname in enumerate(devnames) if cudaDeviceName in devname]
......@@ -149,12 +150,13 @@ def init_transient_fstat_map_features ( cudaDeviceName ):
if len(devmatches) > 1:
logging.warning('Found {} CUDA devices matching name "{}". Choosing first one with index {}.'.format(len(devmatches),cudaDeviceName,devnum))
os.environ['CUDA_DEVICE'] = str(devnum)
matchbit = '(matched to user request "{}")'.format(cudaDeviceName)
elif 'CUDA_DEVICE' in os.environ:
devnum = int(os.environ['CUDA_DEVICE'])
else:
devnum = 0
devn = devices[devnum]
logging.info('Choosing CUDA device {}, of {} devices present: {} (matched to user request "{}")...'.format(devnum,num_gpus,devn.name(),cudaDeviceName))
logging.info('Choosing CUDA device {}, of {} devices present: {}{}...'.format(devnum,num_gpus,devn.name(),matchbit))
if devnum == devnum0:
gpu_context = context0
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment