Commit 3ac12eb2 authored by David Keitel's avatar David Keitel

tCWmap init: only do CUDA stuff if requested

parent 1c349df9
...@@ -660,7 +660,7 @@ class ComputeFstat(BaseSearchClass): ...@@ -660,7 +660,7 @@ class ComputeFstat(BaseSearchClass):
if self.dtau: if self.dtau:
self.windowRange.dtau = self.dtau self.windowRange.dtau = self.dtau
self.tCWFstatMapFeatures, self.gpu_context = tcw.init_transient_fstat_map_features(self.cudaDeviceName) self.tCWFstatMapFeatures, self.gpu_context = tcw.init_transient_fstat_map_features ( self.tCWFstatMapVersion=='pycuda', self.cudaDeviceName )
def get_fullycoherent_twoF(self, tstart, tend, F0, F1, F2, Alpha, Delta, def get_fullycoherent_twoF(self, tstart, tend, F0, F1, F2, Alpha, Delta,
asini=None, period=None, ecc=None, tp=None, asini=None, period=None, ecc=None, tp=None,
......
...@@ -89,7 +89,7 @@ fstatmap_versions = { ...@@ -89,7 +89,7 @@ fstatmap_versions = {
} }
def init_transient_fstat_map_features ( cudaDeviceName ): def init_transient_fstat_map_features ( wantCuda=False, cudaDeviceName=None ):
''' '''
Initialization of available modules (or "features") for F-stat maps. Initialization of available modules (or "features") for F-stat maps.
...@@ -115,7 +115,7 @@ def init_transient_fstat_map_features ( cudaDeviceName ): ...@@ -115,7 +115,7 @@ def init_transient_fstat_map_features ( cudaDeviceName ):
logging.debug('Got the following features for transient F-stat maps:') logging.debug('Got the following features for transient F-stat maps:')
logging.debug(features) logging.debug(features)
if features['pycuda']: if wantCuda and features['pycuda']:
logging.debug('CUDA version: {}'.format(drv.get_version())) logging.debug('CUDA version: {}'.format(drv.get_version()))
drv.init() drv.init()
...@@ -138,6 +138,7 @@ def init_transient_fstat_map_features ( cudaDeviceName ): ...@@ -138,6 +138,7 @@ def init_transient_fstat_map_features ( cudaDeviceName ):
else: else:
devnum0 = 0 devnum0 = 0
matchbit = ''
if cudaDeviceName: if cudaDeviceName:
# allow partial matches in device names # allow partial matches in device names
devmatches = [devidx for devidx, devname in enumerate(devnames) if cudaDeviceName in devname] devmatches = [devidx for devidx, devname in enumerate(devnames) if cudaDeviceName in devname]
...@@ -149,12 +150,13 @@ def init_transient_fstat_map_features ( cudaDeviceName ): ...@@ -149,12 +150,13 @@ def init_transient_fstat_map_features ( cudaDeviceName ):
if len(devmatches) > 1: if len(devmatches) > 1:
logging.warning('Found {} CUDA devices matching name "{}". Choosing first one with index {}.'.format(len(devmatches),cudaDeviceName,devnum)) logging.warning('Found {} CUDA devices matching name "{}". Choosing first one with index {}.'.format(len(devmatches),cudaDeviceName,devnum))
os.environ['CUDA_DEVICE'] = str(devnum) os.environ['CUDA_DEVICE'] = str(devnum)
matchbit = '(matched to user request "{}")'.format(cudaDeviceName)
elif 'CUDA_DEVICE' in os.environ: elif 'CUDA_DEVICE' in os.environ:
devnum = int(os.environ['CUDA_DEVICE']) devnum = int(os.environ['CUDA_DEVICE'])
else: else:
devnum = 0 devnum = 0
devn = devices[devnum] devn = devices[devnum]
logging.info('Choosing CUDA device {}, of {} devices present: {} (matched to user request "{}")...'.format(devnum,num_gpus,devn.name(),cudaDeviceName)) logging.info('Choosing CUDA device {}, of {} devices present: {}{}...'.format(devnum,num_gpus,devn.name(),matchbit))
if devnum == devnum0: if devnum == devnum0:
gpu_context = context0 gpu_context = context0
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment