From 7e71b9d2bc746e93724a9055277a3c495a0ae942 Mon Sep 17 00:00:00 2001 From: David Keitel <david.keitel@ligo.org> Date: Thu, 8 Feb 2018 12:03:56 +0000 Subject: [PATCH] pyCUDA initialisation: fix check for $CUDA_DEVICE exceeding device count -previously check was done in _optional_import [called from init_transient_fstat_map_features()] -this was the right place only back when still importing autoinit -now do check at the time of make_default_context() --- pyfstat/tcw_fstat_map_funcs.py | 56 +++++++++++++--------------------- 1 file changed, 22 insertions(+), 34 deletions(-) diff --git a/pyfstat/tcw_fstat_map_funcs.py b/pyfstat/tcw_fstat_map_funcs.py index f84d71d..4237fa9 100644 --- a/pyfstat/tcw_fstat_map_funcs.py +++ b/pyfstat/tcw_fstat_map_funcs.py @@ -26,39 +26,17 @@ def _optional_import ( modulename, shorthand=None ): else: shorthandbit = ' as '+shorthand - if('pycuda' in sys.modules): - try: - globals()[shorthand] = imp.import_module(modulename) - logging.debug('Successfully imported module %s%s.' - % (modulename, shorthandbit)) - success = True - except pycuda._driver.LogicError, e: - if e.message == 'cuDeviceGet failed: invalid device ordinal': - devn = int(os.environ['CUDA_DEVICE']) - raise RuntimeError('Requested CUDA device number {} exceeds' \ - ' number of available devices!' \ - ' Please change through environment' \ - ' variable $CUDA_DEVICE.'.format(devn)) - else: - raise pycuda._driver.LogicError(e.message) - except ImportError, e: - if e.message == 'No module named '+modulename: - logging.debug('No module {:s} found.'.format(modulename)) - success = False - else: - raise - else: - try: - globals()[shorthand] = imp.import_module(modulename) - logging.debug('Successfully imported module %s%s.' - % (modulename, shorthandbit)) - success = True - except ImportError, e: - if e.message == 'No module named '+modulename: - logging.debug('No module {:s} found.'.format(modulename)) - success = False - else: - raise + try: + globals()[shorthand] = imp.import_module(modulename) + logging.debug('Successfully imported module %s%s.' + % (modulename, shorthandbit)) + success = True + except ImportError, e: + if e.message == 'No module named '+modulename: + logging.debug('No module {:s} found.'.format(modulename)) + success = False + else: + raise return success @@ -131,7 +109,17 @@ def init_transient_fstat_map_features ( wantCuda=False, cudaDeviceName=None ): drv.init() logging.debug('Starting with default pyCUDA context,' \ ' then checking all available devices...') - context0 = pycuda.tools.make_default_context() + try: + context0 = pycuda.tools.make_default_context() + except pycuda._driver.LogicError, e: + if e.message == 'cuDeviceGet failed: invalid device ordinal': + devn = int(os.environ['CUDA_DEVICE']) + raise RuntimeError('Requested CUDA device number {} exceeds' \ + ' number of available devices!' \ + ' Please change through environment' \ + ' variable $CUDA_DEVICE.'.format(devn)) + else: + raise pycuda._driver.LogicError(e.message) num_gpus = drv.Device.count() logging.debug('Found {} CUDA device(s).'.format(num_gpus)) -- GitLab