Skip to content
Snippets Groups Projects
Commit 7e71b9d2 authored by David Keitel's avatar David Keitel
Browse files

pyCUDA initialisation: fix check for $CUDA_DEVICE exceeding device count

 -previously check was done in _optional_import
  [called from init_transient_fstat_map_features()]
 -this was the right place only back when still importing autoinit
 -now do check at the time of make_default_context()
parent 6de7b834
Branches
Tags
No related merge requests found
......@@ -26,28 +26,6 @@ def _optional_import ( modulename, shorthand=None ):
else:
shorthandbit = ' as '+shorthand
if('pycuda' in sys.modules):
try:
globals()[shorthand] = imp.import_module(modulename)
logging.debug('Successfully imported module %s%s.'
% (modulename, shorthandbit))
success = True
except pycuda._driver.LogicError, e:
if e.message == 'cuDeviceGet failed: invalid device ordinal':
devn = int(os.environ['CUDA_DEVICE'])
raise RuntimeError('Requested CUDA device number {} exceeds' \
' number of available devices!' \
' Please change through environment' \
' variable $CUDA_DEVICE.'.format(devn))
else:
raise pycuda._driver.LogicError(e.message)
except ImportError, e:
if e.message == 'No module named '+modulename:
logging.debug('No module {:s} found.'.format(modulename))
success = False
else:
raise
else:
try:
globals()[shorthand] = imp.import_module(modulename)
logging.debug('Successfully imported module %s%s.'
......@@ -131,7 +109,17 @@ def init_transient_fstat_map_features ( wantCuda=False, cudaDeviceName=None ):
drv.init()
logging.debug('Starting with default pyCUDA context,' \
' then checking all available devices...')
try:
context0 = pycuda.tools.make_default_context()
except pycuda._driver.LogicError, e:
if e.message == 'cuDeviceGet failed: invalid device ordinal':
devn = int(os.environ['CUDA_DEVICE'])
raise RuntimeError('Requested CUDA device number {} exceeds' \
' number of available devices!' \
' Please change through environment' \
' variable $CUDA_DEVICE.'.format(devn))
else:
raise pycuda._driver.LogicError(e.message)
num_gpus = drv.Device.count()
logging.debug('Found {} CUDA device(s).'.format(num_gpus))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment