diff --git a/pyfstat/tcw_fstat_map_funcs.py b/pyfstat/tcw_fstat_map_funcs.py index d149a6bfc4c02db3a882189431eca108867a346f..c318c9100e5bf8bea905226909c460d3c54edcb1 100644 --- a/pyfstat/tcw_fstat_map_funcs.py +++ b/pyfstat/tcw_fstat_map_funcs.py @@ -139,7 +139,8 @@ def init_transient_fstat_map_features ( cudaDeviceName ): devnum0 = 0 if cudaDeviceName: - devmatches = np.where(devnames == cudaDeviceName)[0] + # allow partial matches in device names + devmatches = [devidx for devidx, devname in enumerate(devnames) if cudaDeviceName in devname] if len(devmatches) == 0: context0.detach() raise RuntimeError('Requested CUDA device "{}" not found. Available devices: [{}]'.format(cudaDeviceName,','.join(devnames))) @@ -153,7 +154,7 @@ def init_transient_fstat_map_features ( cudaDeviceName ): else: devnum = 0 devn = devices[devnum] - logging.info('Choosing CUDA device {}, of {} devices present: {} (matched to user request "{}")...'.format(devnum,num_gpus,devn.name(),devnames[devnum])) + logging.info('Choosing CUDA device {}, of {} devices present: {} (matched to user request "{}")...'.format(devnum,num_gpus,devn.name(),cudaDeviceName)) if devnum == devnum0: gpu_context = context0 else: