diff --git a/pyfstat/tcw_fstat_map_funcs.py b/pyfstat/tcw_fstat_map_funcs.py index 09b9e28e656694fc9472f329669638682aff5d52..b054d20ff01978d92023732ae77b72488f422769 100644 --- a/pyfstat/tcw_fstat_map_funcs.py +++ b/pyfstat/tcw_fstat_map_funcs.py @@ -129,7 +129,10 @@ def init_transient_fstat_map_features ( ): for n, devn in enumerate(devices): logging.debug('device {} model: {}, RAM: {}MB'.format(n,devn.name(),devn.total_memory()/(2.**20) )) - devnum = int(os.environ['CUDA_DEVICE']) + if 'CUDA_DEVICE' in os.environ: + devnum = int(os.environ['CUDA_DEVICE']) + else: + devnum = 0 devn = drv.Device(devnum) logging.info('Choosing CUDA device {}, of {} devices present: {}... (Can be changed through environment variable $CUDA_DEVICE.)'.format(devnum,num_gpus,devn.name())) logging.debug('Available GPU memory: {}/{} MB free'.format(drv.mem_get_info()[0]/(2.**20),drv.mem_get_info()[1]/(2.**20)))