From 0738b06dff1d60660b6ef6e84de7a54a65617f0a Mon Sep 17 00:00:00 2001
From: David Keitel <david.keitel@ligo.org>
Date: Fri, 19 Jan 2018 18:24:06 +0000
Subject: [PATCH] GPU device selection: through cudaDeviceName user option

 -requires manual context initialisation and cleanup
---
 pyfstat/core.py                | 15 +++++++++--
 pyfstat/grid_based_searches.py | 11 ++++++--
 pyfstat/tcw_fstat_map_funcs.py | 48 +++++++++++++++++++++++++++-------
 3 files changed, 60 insertions(+), 14 deletions(-)

diff --git a/pyfstat/core.py b/pyfstat/core.py
index abf1856..6a62498 100755
--- a/pyfstat/core.py
+++ b/pyfstat/core.py
@@ -337,7 +337,7 @@ class ComputeFstat(BaseSearchClass):
                  detectors=None, minCoverFreq=None, maxCoverFreq=None,
                  injectSources=None, injectSqrtSX=None, assumeSqrtSX=None,
                  SSBprec=None,
-                 tCWFstatMapVersion='lal'):
+                 tCWFstatMapVersion='lal', cudaDeviceName=None):
         """
         Parameters
         ----------
@@ -388,6 +388,8 @@ class ComputeFstat(BaseSearchClass):
         tCWFstatMapVersion: str
             Choose between standard 'lal' implementation,
             'pycuda' for gpu, and some others for devel/debug.
+        cudaDeviceName: str
+            GPU name to be matched against drv.Device output.
 
         """
 
@@ -658,7 +660,7 @@ class ComputeFstat(BaseSearchClass):
                     if self.dtau:
                         self.windowRange.dtau = self.dtau
 
-            self.tCWFstatMapFeatures = tcw.init_transient_fstat_map_features()
+            self.tCWFstatMapFeatures, self.gpu_context = tcw.init_transient_fstat_map_features(self.cudaDeviceName)
 
     def get_fullycoherent_twoF(self, tstart, tend, F0, F1, F2, Alpha, Delta,
                                asini=None, period=None, ecc=None, tp=None,
@@ -939,6 +941,15 @@ class ComputeFstat(BaseSearchClass):
             raise RuntimeError('Cannot print atoms vector to file: no FstatResults.multiFatoms, or it is None!')
 
 
+    def __del__(self):
+        """
+        In the pyCuda case without autoinit,
+        we need to make sure the context is removed at the end.
+        """
+        if hasattr(self,'gpu_context') and self.gpu_context:
+            self.gpu_context.detach()
+
+
 class SemiCoherentSearch(ComputeFstat):
     """ A semi-coherent search """
 
diff --git a/pyfstat/grid_based_searches.py b/pyfstat/grid_based_searches.py
index 29d590f..1dc24f2 100644
--- a/pyfstat/grid_based_searches.py
+++ b/pyfstat/grid_based_searches.py
@@ -356,7 +356,7 @@ class TransientGridSearch(GridSearch):
                  dt0=None, dtau=None,
                  outputTransientFstatMap=False,
                  outputAtoms=False,
-                 tCWFstatMapVersion='lal'):
+                 tCWFstatMapVersion='lal', cudaDeviceName=None):
         """
         Parameters
         ----------
@@ -392,6 +392,8 @@ class TransientGridSearch(GridSearch):
         tCWFstatMapVersion: str
             Choose between standard 'lal' implementation,
             'pycuda' for gpu, and some others for devel/debug.
+        cudaDeviceName: str
+            GPU name to be matched against drv.Device output.
 
         For all other parameters, see `pyfstat.ComputeFStat` for details
         """
@@ -418,7 +420,8 @@ class TransientGridSearch(GridSearch):
             BSGL=self.BSGL, SSBprec=self.SSBprec,
             injectSources=self.injectSources,
             assumeSqrtSX=self.assumeSqrtSX,
-            tCWFstatMapVersion=self.tCWFstatMapVersion)
+            tCWFstatMapVersion=self.tCWFstatMapVersion,
+            cudaDeviceName=self.cudaDeviceName)
         self.search.get_det_stat = self.search.get_fullycoherent_twoF
 
     def run(self, return_data=False):
@@ -473,6 +476,10 @@ class TransientGridSearch(GridSearch):
                     this_tau = windowRange.tau + n * windowRange.dtau;
                     tfp.write('  %10d %10d %- 11.8g\n' % (this_t0, this_tau, 2.0*this_F))
 
+    def __del__(self):
+        if hasattr(self,'search'):
+            self.search.__del__()
+
 
 class SliceGridSearch(GridSearch):
     """ Slice gridded search using ComputeFstat """
diff --git a/pyfstat/tcw_fstat_map_funcs.py b/pyfstat/tcw_fstat_map_funcs.py
index b054d20..d149a6b 100644
--- a/pyfstat/tcw_fstat_map_funcs.py
+++ b/pyfstat/tcw_fstat_map_funcs.py
@@ -89,7 +89,7 @@ fstatmap_versions = {
                     }
 
 
-def init_transient_fstat_map_features ( ):
+def init_transient_fstat_map_features ( cudaDeviceName ):
     '''
     Initialization of available modules (or "features") for F-stat maps.
 
@@ -106,12 +106,11 @@ def init_transient_fstat_map_features ( ):
 
     # import GPU features
     have_pycuda          = optional_import('pycuda')
-    have_pycuda_init     = optional_import('pycuda.autoinit', 'autoinit')
     have_pycuda_drv      = optional_import('pycuda.driver', 'drv')
     have_pycuda_gpuarray = optional_import('pycuda.gpuarray', 'gpuarray')
     have_pycuda_tools    = optional_import('pycuda.tools', 'cudatools')
     have_pycuda_compiler = optional_import('pycuda.compiler', 'cudacomp')
-    features['pycuda']   = have_pycuda_drv and have_pycuda_init and have_pycuda_gpuarray and have_pycuda_tools and have_pycuda_compiler
+    features['pycuda']   = have_pycuda_drv and have_pycuda_gpuarray and have_pycuda_tools and have_pycuda_compiler
 
     logging.debug('Got the following features for transient F-stat maps:')
     logging.debug(features)
@@ -119,25 +118,54 @@ def init_transient_fstat_map_features ( ):
     if features['pycuda']:
         logging.debug('CUDA version: {}'.format(drv.get_version()))
 
+        drv.init()
+        logging.debug('Starting with default context, then checking all available devices...')
+        context0 = pycuda.tools.make_default_context()
+
         num_gpus = drv.Device.count()
         logging.debug('Found {} CUDA device(s).'.format(num_gpus))
 
         devices = []
+        devnames = np.empty(num_gpus,dtype='S32')
         for n in range(num_gpus):
-            devices.append(drv.Device(n))
-
-        for n, devn in enumerate(devices):
-            logging.debug('device {} model: {}, RAM: {}MB'.format(n,devn.name(),devn.total_memory()/(2.**20) ))
+            devn = drv.Device(n)
+            devices.append(devn)
+            devnames[n] = devn.name().replace(' ','-').replace('_','-')
+            logging.debug('device {}: model: {}, RAM: {}MB'.format(n,devnames[n],devn.total_memory()/(2.**20) ))
 
         if 'CUDA_DEVICE' in os.environ:
+            devnum0 = int(os.environ['CUDA_DEVICE'])
+        else:
+            devnum0 = 0
+
+        if cudaDeviceName:
+            devmatches = np.where(devnames == cudaDeviceName)[0]
+            if len(devmatches) == 0:
+                context0.detach()
+                raise RuntimeError('Requested CUDA device "{}" not found. Available devices: [{}]'.format(cudaDeviceName,','.join(devnames)))
+            else:
+                devnum = devmatches[0]
+                if len(devmatches) > 1:
+                    logging.warning('Found {} CUDA devices matching name "{}". Choosing first one with index {}.'.format(len(devmatches),cudaDeviceName,devnum))
+            os.environ['CUDA_DEVICE'] = str(devnum)
+        elif 'CUDA_DEVICE' in os.environ:
             devnum = int(os.environ['CUDA_DEVICE'])
         else:
             devnum = 0
-        devn = drv.Device(devnum)
-        logging.info('Choosing CUDA device {}, of {} devices present: {}... (Can be changed through environment variable $CUDA_DEVICE.)'.format(devnum,num_gpus,devn.name()))
+        devn = devices[devnum]
+        logging.info('Choosing CUDA device {}, of {} devices present: {} (matched to user request "{}")...'.format(devnum,num_gpus,devn.name(),devnames[devnum]))
+        if devnum == devnum0:
+            gpu_context = context0
+        else:
+            context0.pop()
+            gpu_context = pycuda.tools.make_default_context()
+            gpu_context.push()
+
         logging.debug('Available GPU memory: {}/{} MB free'.format(drv.mem_get_info()[0]/(2.**20),drv.mem_get_info()[1]/(2.**20)))
+    else:
+        gpu_context = None
 
-    return features
+    return features, gpu_context
 
 
 def call_compute_transient_fstat_map ( version, features, multiFstatAtoms=None, windowRange=None ):
-- 
GitLab