diff --git a/pyfstat/tcw_fstat_map_funcs.py b/pyfstat/tcw_fstat_map_funcs.py
index d149a6bfc4c02db3a882189431eca108867a346f..c318c9100e5bf8bea905226909c460d3c54edcb1 100644
--- a/pyfstat/tcw_fstat_map_funcs.py
+++ b/pyfstat/tcw_fstat_map_funcs.py
@@ -139,7 +139,8 @@ def init_transient_fstat_map_features ( cudaDeviceName ):
             devnum0 = 0
 
         if cudaDeviceName:
-            devmatches = np.where(devnames == cudaDeviceName)[0]
+            # allow partial matches in device names
+            devmatches = [devidx for devidx, devname in enumerate(devnames) if cudaDeviceName in devname]
             if len(devmatches) == 0:
                 context0.detach()
                 raise RuntimeError('Requested CUDA device "{}" not found. Available devices: [{}]'.format(cudaDeviceName,','.join(devnames)))
@@ -153,7 +154,7 @@ def init_transient_fstat_map_features ( cudaDeviceName ):
         else:
             devnum = 0
         devn = devices[devnum]
-        logging.info('Choosing CUDA device {}, of {} devices present: {} (matched to user request "{}")...'.format(devnum,num_gpus,devn.name(),devnames[devnum]))
+        logging.info('Choosing CUDA device {}, of {} devices present: {} (matched to user request "{}")...'.format(devnum,num_gpus,devn.name(),cudaDeviceName))
         if devnum == devnum0:
             gpu_context = context0
         else: