diff --git a/pyfstat/tcw_fstat_map_funcs.py b/pyfstat/tcw_fstat_map_funcs.py
index fd43b289dd660a1ea2a13e6cfb8ac46b92f5eb7c..09b9e28e656694fc9472f329669638682aff5d52 100644
--- a/pyfstat/tcw_fstat_map_funcs.py
+++ b/pyfstat/tcw_fstat_map_funcs.py
@@ -2,6 +2,7 @@
 
 import numpy as np
 import os
+import sys
 import logging
 
 # optional imports
@@ -14,22 +15,46 @@ def optional_import ( modulename, shorthand=None ):
 
     using importlib instead of __import__
     because the latter doesn't handle sub.modules
+
+    Also including a special check to fail more gracefully
+    when CUDA_DEVICE is set to too high a number.
     '''
+
     if shorthand is None:
         shorthand    = modulename
         shorthandbit = ''
     else:
         shorthandbit = ' as '+shorthand
-    try:
-        globals()[shorthand] = imp.import_module(modulename)
-        #logging.debug('Successfully imported module %s%s.' % (modulename, shorthandbit))
-        success = True
-    except ImportError, e:
-        if e.message == 'No module named '+modulename:
-            logging.warning('No module {:s} found.'.format(modulename))
-            success = False
-        else:
-            raise
+
+    if('pycuda' in sys.modules):
+        try:
+            globals()[shorthand] = imp.import_module(modulename)
+            logging.debug('Successfully imported module %s%s.' % (modulename, shorthandbit))
+            success = True
+        except pycuda._driver.LogicError, e:
+            if e.message == 'cuDeviceGet failed: invalid device ordinal':
+                devn = int(os.environ['CUDA_DEVICE'])
+                raise RuntimeError('Requested CUDA device number {} exceeds number of available devices! Please change through environment variable $CUDA_DEVICE.'.format(devn))
+            else:
+                raise pycuda._driver.LogicError(e.message)
+        except ImportError, e:
+            if e.message == 'No module named '+modulename:
+                logging.debug('No module {:s} found.'.format(modulename))
+                success = False
+            else:
+                raise
+    else:
+        try:
+            globals()[shorthand] = imp.import_module(modulename)
+            logging.debug('Successfully imported module %s%s.' % (modulename, shorthandbit))
+            success = True
+        except ImportError, e:
+            if e.message == 'No module named '+modulename:
+                logging.debug('No module {:s} found.'.format(modulename))
+                success = False
+            else:
+                raise
+
     return success
 
 
@@ -80,8 +105,9 @@ def init_transient_fstat_map_features ( ):
     features['lal']    = have_lal and have_lalpulsar
 
     # import GPU features
-    have_pycuda_drv      = optional_import('pycuda.driver', 'drv')
+    have_pycuda          = optional_import('pycuda')
     have_pycuda_init     = optional_import('pycuda.autoinit', 'autoinit')
+    have_pycuda_drv      = optional_import('pycuda.driver', 'drv')
     have_pycuda_gpuarray = optional_import('pycuda.gpuarray', 'gpuarray')
     have_pycuda_tools    = optional_import('pycuda.tools', 'cudatools')
     have_pycuda_compiler = optional_import('pycuda.compiler', 'cudacomp')
@@ -90,6 +116,24 @@ def init_transient_fstat_map_features ( ):
     logging.debug('Got the following features for transient F-stat maps:')
     logging.debug(features)
 
+    if features['pycuda']:
+        logging.debug('CUDA version: {}'.format(drv.get_version()))
+
+        num_gpus = drv.Device.count()
+        logging.debug('Found {} CUDA device(s).'.format(num_gpus))
+
+        # Log model and total RAM of every device the driver reports.
+        for n in range(num_gpus):
+            devn = drv.Device(n)
+            logging.debug('device {} model: {}, RAM: {}MB'.format(n,devn.name(),devn.total_memory()/(2.**20) ))
+
+        # $CUDA_DEVICE is optional: default to device 0 when unset instead of raising KeyError. NOTE(review): confirm 0 matches the device pycuda.autoinit actually picked.
+        devnum = int(os.environ.get('CUDA_DEVICE', 0))
+        devn = drv.Device(devnum)
+        logging.info('Choosing CUDA device {}, of {} devices present: {}... (Can be changed through environment variable $CUDA_DEVICE.)'.format(devnum,num_gpus,devn.name()))
+        free_mem, total_mem = drv.mem_get_info()  # query once: (free, total) in bytes
+        logging.debug('Available GPU memory: {}/{} MB free'.format(free_mem/(2.**20),total_mem/(2.**20)))
+
     return features