diff --git a/examples/transient_examples/short_transient_search_gridded.py b/examples/transient_examples/short_transient_search_gridded.py
index dea9c26745e968bc6af825b61aaeb78b72fef22a..5011b8052ef325bc7b3590195b9112a335b1cdaf 100644
--- a/examples/transient_examples/short_transient_search_gridded.py
+++ b/examples/transient_examples/short_transient_search_gridded.py
@@ -50,7 +50,8 @@ search2 = pyfstat.TransientGridSearch(
     minStartTime=minStartTime, maxStartTime=maxStartTime,
     transientWindowType='rect', t0Band=Tspan-2*Tsft, tauBand=Tspan,
     BSGL=False,
-    outputTransientFstatMap=True)
+    outputTransientFstatMap=True,
+    tCWFstatMapVersion='lal')
 search2.run()
 search2.print_max_twoF()
 
diff --git a/pyfstat/core.py b/pyfstat/core.py
index 4482121b3c3b09d9fd3d90182389720010d00f44..50705e73f48d8ceb6c0ff0db3314a49c3426b7d2 100755
--- a/pyfstat/core.py
+++ b/pyfstat/core.py
@@ -13,6 +13,7 @@ import scipy.optimize
 import lal
 import lalpulsar
 import pyfstat.helper_functions as helper_functions
+import pyfstat.tcw_fstat_map_funcs as tcw
 
 # workaround for matplotlib on X-less remote logins
 if 'DISPLAY' in os.environ:
@@ -335,7 +336,8 @@ class ComputeFstat(BaseSearchClass):
                  dt0=None, dtau=None,
                  detectors=None, minCoverFreq=None, maxCoverFreq=None,
                  injectSources=None, injectSqrtSX=None, assumeSqrtSX=None,
-                 SSBprec=None):
+                 SSBprec=None,
+                 tCWFstatMapVersion='lal', cudaDeviceName=None):
         """
         Parameters
         ----------
@@ -383,6 +385,11 @@ class ComputeFstat(BaseSearchClass):
         SSBprec : int
             Flag to set the SSB calculation: 0=Newtonian, 1=relativistic,
             2=relativisitic optimised, 3=DMoff, 4=NO_SPIN
+        tCWFstatMapVersion: str
+            Choose between standard 'lal' implementation,
+            'pycuda' for gpu, and some others for devel/debug.
+        cudaDeviceName: str
+            GPU name to be matched against drv.Device output.
 
         """
 
@@ -625,13 +632,14 @@ class ComputeFstat(BaseSearchClass):
             self.windowRange.dt0 = self.Tsft
             self.windowRange.dtau = self.Tsft
 
-            # special treatment of window_type = none ==> replace by rectangular window spanning all the data
+            # special treatment of window_type = none
+            # ==> replace by rectangular window spanning all the data
             if self.windowRange.type == lalpulsar.TRANSIENT_NONE:
                 self.windowRange.t0 = int(self.minStartTime)
                 self.windowRange.t0Band = 0
                 self.windowRange.tau = int(self.maxStartTime-self.minStartTime)
                 self.windowRange.tauBand = 0
-            else: # user-set bands and spacings
+            else:  # user-set bands and spacings
                 if self.t0Band is None:
                     self.windowRange.t0Band = 0
                 else:
@@ -653,6 +661,11 @@ class ComputeFstat(BaseSearchClass):
                     if self.dtau:
                         self.windowRange.dtau = self.dtau
 
+            logging.info('Initialising transient FstatMap features...')
+            self.tCWFstatMapFeatures, self.gpu_context = (
+                tcw.init_transient_fstat_map_features(
+                    self.tCWFstatMapVersion == 'pycuda', self.cudaDeviceName))
+
     def get_fullycoherent_twoF(self, tstart, tend, F0, F1, F2, Alpha, Delta,
                                asini=None, period=None, ecc=None, tp=None,
                                argp=None):
@@ -695,9 +708,13 @@ class ComputeFstat(BaseSearchClass):
             # F-stat computation
             self.windowRange.tau = int(2*self.Tsft)
 
-        self.FstatMap = lalpulsar.ComputeTransientFstatMap(
-            self.FstatResults.multiFatoms[0], self.windowRange, False)
-        F_mn = self.FstatMap.F_mn.data
+        self.FstatMap = tcw.call_compute_transient_fstat_map(
+            self.tCWFstatMapVersion, self.tCWFstatMapFeatures,
+            self.FstatResults.multiFatoms[0], self.windowRange)
+        if self.tCWFstatMapVersion == 'lal':
+            F_mn = self.FstatMap.F_mn.data
+        else:
+            F_mn = self.FstatMap.F_mn
 
         twoF = 2*np.max(F_mn)
         if self.BSGL is False:
@@ -920,6 +937,15 @@ class ComputeFstat(BaseSearchClass):
             raise RuntimeError('Cannot print atoms vector to file: no FstatResults.multiFatoms, or it is None!')
 
 
+    def __del__(self):
+        """
+        Detach the pyCUDA context held in self.gpu_context, if there is one.
+        Needed at the end because pyCUDA is used here without autoinit.
+        """
+        if hasattr(self,'gpu_context') and self.gpu_context:  # may be unset or None
+            self.gpu_context.detach()
+
+
 class SemiCoherentSearch(ComputeFstat):
     """ A semi-coherent search """
 
@@ -950,6 +976,8 @@ class SemiCoherentSearch(ComputeFstat):
         self.transientWindowType = 'rect'
         self.t0Band  = None
         self.tauBand = None
+        self.tCWFstatMapVersion = 'lal'
+        self.cudaDeviceName = None
         self.init_computefstatistic_single_point()
         self.init_semicoherent_parameters()
 
@@ -1089,6 +1117,8 @@ class SemiCoherentGlitchSearch(ComputeFstat):
         self.transientWindowType = 'rect'
         self.t0Band  = None
         self.tauBand = None
+        self.tCWFstatMapVersion = 'lal'
+        self.cudaDeviceName = None
         self.binary  = False
         self.init_computefstatistic_single_point()
 
diff --git a/pyfstat/grid_based_searches.py b/pyfstat/grid_based_searches.py
index 42b3c3e9a6e67b63815d59cae2cce9af0870017b..8d830c291da203b1c05f8f8008f7cd96cb675c2e 100644
--- a/pyfstat/grid_based_searches.py
+++ b/pyfstat/grid_based_searches.py
@@ -355,7 +355,8 @@ class TransientGridSearch(GridSearch):
                  transientWindowType=None, t0Band=None, tauBand=None,
                  dt0=None, dtau=None,
                  outputTransientFstatMap=False,
-                 outputAtoms=False):
+                 outputAtoms=False,
+                 tCWFstatMapVersion='lal', cudaDeviceName=None):
         """
         Parameters
         ----------
@@ -388,6 +389,11 @@ class TransientGridSearch(GridSearch):
         outputTransientFstatMap: bool
             if true, write output files for (t0,tau) Fstat maps
             (one file for each doppler grid point!)
+        tCWFstatMapVersion: str
+            Choose between standard 'lal' implementation,
+            'pycuda' for gpu, and some others for devel/debug.
+        cudaDeviceName: str
+            GPU name to be matched against drv.Device output.
 
         For all other parameters, see `pyfstat.ComputeFStat` for details
         """
@@ -413,7 +419,9 @@ class TransientGridSearch(GridSearch):
             minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
             BSGL=self.BSGL, SSBprec=self.SSBprec,
             injectSources=self.injectSources,
-            assumeSqrtSX=self.assumeSqrtSX)
+            assumeSqrtSX=self.assumeSqrtSX,
+            tCWFstatMapVersion=self.tCWFstatMapVersion,
+            cudaDeviceName=self.cudaDeviceName)
         self.search.get_det_stat = self.search.get_fullycoherent_twoF
 
     def run(self, return_data=False):
@@ -427,19 +435,36 @@ class TransientGridSearch(GridSearch):
             self.inititate_search_object()
 
         data = []
+        if self.outputTransientFstatMap:
+            tCWfilebase = os.path.splitext(self.out_file)[0] + '_tCW_'
+            logging.info('Will save per-Doppler Fstatmap' \
+                         ' results to {}*.dat'.format(tCWfilebase))
         for vals in tqdm(self.input_data):
             detstat = self.search.get_det_stat(*vals)
             windowRange = getattr(self.search, 'windowRange', None)
             FstatMap = getattr(self.search, 'FstatMap', None)
             thisCand = list(vals) + [detstat]
             if getattr(self, 'transientWindowType', None):
+                if self.tCWFstatMapVersion == 'lal':
+                    F_mn = FstatMap.F_mn.data
+                else:
+                    F_mn = FstatMap.F_mn
                 if self.outputTransientFstatMap:
-                    tCWfile = os.path.splitext(self.out_file)[0]+'_tCW_%.16f_%.16f_%.16f_%.16g_%.16g.dat' % (vals[2],vals[5],vals[6],vals[3],vals[4]) # freq alpha delta f1dot f2dot
-                    fo = lal.FileOpen(tCWfile, 'w')
-                    lalpulsar.write_transientFstatMap_to_fp ( fo, FstatMap, windowRange, None )
-                    del fo # instead of lal.FileClose() which is not SWIG-exported
-                Fmn = FstatMap.F_mn.data
-                maxidx = np.unravel_index(Fmn.argmax(), Fmn.shape)
+                    # per-Doppler filename convention:
+                    # freq alpha delta f1dot f2dot
+                    tCWfile = ( tCWfilebase
+                                + '%.16f_%.16f_%.16f_%.16g_%.16g.dat' %
+                                (vals[2],vals[5],vals[6],vals[3],vals[4]) )
+                    if self.tCWFstatMapVersion == 'lal':
+                        fo = lal.FileOpen(tCWfile, 'w')
+                        lalpulsar.write_transientFstatMap_to_fp (
+                            fo, FstatMap, windowRange, None )
+                        # instead of lal.FileClose(),
+                        # which is not SWIG-exported:
+                        del fo
+                    else:
+                        self.write_F_mn ( tCWfile, F_mn, windowRange)
+                maxidx = np.unravel_index(F_mn.argmax(), F_mn.shape)
                 thisCand += [windowRange.t0+maxidx[0]*windowRange.dt0,
                              windowRange.tau+maxidx[1]*windowRange.dtau]
             data.append(thisCand)
@@ -453,6 +478,19 @@ class TransientGridSearch(GridSearch):
             self.save_array_to_disk(data)
             self.data = data
 
+    def write_F_mn (self, tCWfile, F_mn, windowRange ):  # dump an F_mn map to the text file tCWfile
+        with open(tCWfile, 'w') as tfp:
+            tfp.write('# t0 [s]     tau [s]     2F\n')  # column header line
+            for m, F_m in enumerate(F_mn):  # m: first index of F_mn, steps t0 by dt0
+                this_t0 = windowRange.t0 + m * windowRange.dt0
+                for n, this_F in enumerate(F_m):  # n: second index, steps tau by dtau
+                    this_tau = windowRange.tau + n * windowRange.dtau;
+                    tfp.write('  %10d %10d %- 11.8g\n' % (this_t0, this_tau, 2.0*this_F))  # value is doubled to match the 2F column
+
+    def __del__(self):
+        if hasattr(self,'search'):  # the search object may not have been created
+            self.search.__del__()  # NOTE(review): Python may call search.__del__ again when search is collected -- confirm a repeat is harmless
+
 
 class SliceGridSearch(GridSearch):
     """ Slice gridded search using ComputeFstat """
diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py
index b09c7d0397c7133862993bced37bed74eed085ae..99b28b66e5a3bc3d26b51c166abbed4fefc2b15c 100644
--- a/pyfstat/mcmc_based_searches.py
+++ b/pyfstat/mcmc_based_searches.py
@@ -82,6 +82,9 @@ class MCMCSearch(core.BaseSearchClass):
         ('none' instead of None explicitly calls the transient-window function,
         but with the full range, for debugging)
         Currently only supported for nsegs=1.
+    tCWFstatMapVersion: str
+        Choose between standard 'lal' implementation,
+        'pycuda' for gpu, and some others for devel/debug.
 
     Attributes
     ----------
@@ -115,7 +118,7 @@ class MCMCSearch(core.BaseSearchClass):
                  rhohatmax=1000, binary=False, BSGL=False,
                  SSBprec=None, minCoverFreq=None, maxCoverFreq=None,
                  injectSources=None, assumeSqrtSX=None,
-                 transientWindowType=None):
+                 transientWindowType=None, tCWFstatMapVersion='lal'):
 
         if os.path.isdir(outdir) is False:
             os.mkdir(outdir)
@@ -161,7 +164,8 @@ class MCMCSearch(core.BaseSearchClass):
             transientWindowType=self.transientWindowType,
             minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
             binary=self.binary, injectSources=self.injectSources,
-            assumeSqrtSX=self.assumeSqrtSX, SSBprec=self.SSBprec)
+            assumeSqrtSX=self.assumeSqrtSX, SSBprec=self.SSBprec,
+            tCWFstatMapVersion=self.tCWFstatMapVersion)
         if self.minStartTime is None:
             self.minStartTime = self.search.minStartTime
         if self.maxStartTime is None:
@@ -2212,7 +2216,8 @@ class MCMCTransientSearch(MCMCSearch):
             transientWindowType=self.transientWindowType,
             minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
             BSGL=self.BSGL, binary=self.binary,
-            injectSources=self.injectSources)
+            injectSources=self.injectSources,
+            tCWFstatMapVersion=self.tCWFstatMapVersion)
         if self.minStartTime is None:
             self.minStartTime = self.search.minStartTime
         if self.maxStartTime is None:
diff --git a/pyfstat/pyCUDAkernels/cudaTransientFstatExpWindow.cu b/pyfstat/pyCUDAkernels/cudaTransientFstatExpWindow.cu
new file mode 100644
index 0000000000000000000000000000000000000000..85d9972c5d6cee5c6e8c4443e8c425b8cb335b8d
--- /dev/null
+++ b/pyfstat/pyCUDAkernels/cudaTransientFstatExpWindow.cu
@@ -0,0 +1,127 @@
+__global__ void cudaTransientFstatExpWindow ( float *input, // atoms matrix, read as input[i*7+k]
+                                              unsigned int numAtoms, // number of atoms (rows of input)
+                                              unsigned int TAtom, // time step between consecutive atoms
+                                              unsigned int t0_data, // timestamp of atom 0
+                                              unsigned int win_t0, // first t0 of the search grid
+                                              unsigned int win_dt0, // t0 grid step
+                                              unsigned int win_tau, // first tau of the search grid
+                                              unsigned int win_dtau, // tau grid step
+                                              unsigned int Fmn_rows, // number of t0 values
+                                              unsigned int Fmn_cols, // number of tau values
+                                              float *Fmn // output, Fmn_rows x Fmn_cols, row-major
+                                            )
+{
+
+  /* match CUDA thread indexing and high-level (t0,tau) indexing */
+  unsigned int m         = blockDim.x * blockIdx.x + threadIdx.x; // t0:  row
+  unsigned int n         = blockDim.y * blockIdx.y + threadIdx.y; // tau: column
+  /* unraveled 1D index for 2D output array */
+  unsigned int outidx    = Fmn_cols * m + n;
+
+  /* hardcoded copy from lalpulsar */
+  unsigned int TRANSIENT_EXP_EFOLDING = 3;
+
+  if ( (m < Fmn_rows) && (n < Fmn_cols) ) {
+
+    /* compute Fstat-atom index i_t0 in [0, numAtoms) */
+    unsigned int TAtomHalf = TAtom/2; // integer division
+    unsigned int t0 = win_t0 + m * win_dt0;
+    /* integer round: floor(x+0.5) */
+    int i_tmp = ( t0 - t0_data + TAtomHalf ) / TAtom; // NOTE(review): all operands are unsigned, so t0 < t0_data wraps instead of going negative -- confirm callers guarantee t0 >= t0_data
+    if ( i_tmp < 0 ) {
+        i_tmp = 0;
+    }
+    unsigned int i_t0 = (unsigned int)i_tmp;
+    if ( i_t0 >= numAtoms ) {
+        i_t0 = numAtoms - 1;
+    }
+
+    /* translate n into an atoms end-index
+     * for this search interval [t0, t0+Tcoh],
+     * giving the index range of atoms to sum over
+     */
+    unsigned int tau = win_tau + n * win_dtau;
+
+    /* get end-time t1 of this transient-window search
+     * for given tau, what Tcoh should the exponential window cover?
+     * for speed reasons we want to truncate
+     * Tcoh = tau * TRANSIENT_EXP_EFOLDING
+     * with the e-folding factor chosen such that the window-value
+     * is practically negligible after that, where it will be set to 0
+     */
+//     unsigned int t1 = lround( win_t0 + TRANSIENT_EXP_EFOLDING * win_tau);
+    unsigned int t1 = t0 + TRANSIENT_EXP_EFOLDING * tau;
+
+      /* compute window end-time Fstat-atom index i_t1 in [0, numAtoms)
+       * using integer round: floor(x+0.5)
+       */
+    i_tmp = ( t1 - t0_data + TAtomHalf ) / TAtom  - 1; // same unsigned-arithmetic caveat as for i_t0
+    if ( i_tmp < 0 ) {
+        i_tmp = 0;
+    }
+    unsigned int i_t1 = (unsigned int)i_tmp;
+    if ( i_t1 >= numAtoms ) {
+        i_t1 = numAtoms - 1;
+    }
+
+    /* now we have two valid atoms-indices [i_t0, i_t1]
+     * spanning our Fstat-window to sum over
+     */
+
+    float Ad    = 0.0f;
+    float Bd    = 0.0f;
+    float Cd    = 0.0f;
+    float Fa_re = 0.0f;
+    float Fa_im = 0.0f;
+    float Fb_re = 0.0f;
+    float Fb_im = 0.0f;
+
+    unsigned short input_cols = 7; // must match input matrix!
+
+    /* sum up atoms */
+    for ( unsigned int i=i_t0; i<=i_t1; i++ ) { // inclusive index range
+
+      unsigned int t_i = t0_data + i * TAtom; // timestamp of atom i
+
+      float win_i = 0.0; // window weight, stays 0 outside [t0, t1]
+      if ( t_i >= t0 && t_i <= t1 ) {
+        float x = 1.0 * ( t_i - t0 ) / tau; // time since t0 in units of tau
+        win_i = exp ( -x );
+      }
+
+      float win2_i = win_i * win_i; // a2/b2/ab sums use the squared weight, Fa/Fb sums the plain weight
+
+      Ad    += input[i*input_cols+0] * win2_i; // a2_alpha
+      Bd    += input[i*input_cols+1] * win2_i; // b2_alpha
+      Cd    += input[i*input_cols+2] * win2_i; // ab_alpha
+      Fa_re += input[i*input_cols+3] * win_i; // Fa_alpha_re
+      Fa_im += input[i*input_cols+4] * win_i; // Fa_alpha_im
+      Fb_re += input[i*input_cols+5] * win_i; // Fb_alpha_re
+      Fb_im += input[i*input_cols+6] * win_i; // Fb_alpha_im
+
+    }
+
+    /* get determinant */
+    float Dd = ( Ad * Bd - Cd * Cd );
+    float DdInv = 0.0f;
+    /* safety catch as in XLALWeightMultiAMCoeffs():
+     * make it so that in the end F=0 instead of -nan
+     */
+    if ( Dd > 0.0 ) {
+      DdInv  = 1.0 / Dd;
+    }
+
+    /* from XLALComputeFstatFromFaFb */
+    float F  = DdInv * (  Bd * ( Fa_re*Fa_re + Fa_im*Fa_im )
+                        + Ad * ( Fb_re*Fb_re + Fb_im*Fb_im )
+                        - 2.0 * Cd * ( Fa_re * Fb_re + Fa_im * Fb_im )
+                       );
+
+    /* store result in Fstat-matrix
+     * at unraveled index of element {m,n}
+     */
+    Fmn[outidx] = F;
+
+  } // ( (m < Fmn_rows) && (n < Fmn_cols) )
+
+} // cudaTransientFstatExpWindow()
diff --git a/pyfstat/pyCUDAkernels/cudaTransientFstatRectWindow.cu b/pyfstat/pyCUDAkernels/cudaTransientFstatRectWindow.cu
new file mode 100644
index 0000000000000000000000000000000000000000..dd353bbcf7fb52d2ff88e8b59d114057754a5f8c
--- /dev/null
+++ b/pyfstat/pyCUDAkernels/cudaTransientFstatRectWindow.cu
@@ -0,0 +1,119 @@
+__global__ void cudaTransientFstatRectWindow ( float *input, // atoms matrix, read as input[i*7+k]
+                                               unsigned int numAtoms, // number of atoms (rows of input)
+                                               unsigned int TAtom, // divisor turning time offsets into atom indices
+                                               unsigned int t0_data, // time origin subtracted before indexing
+                                               unsigned int win_t0, // first t0 of the search grid
+                                               unsigned int win_dt0, // t0 grid step
+                                               unsigned int win_tau, // first tau of the search grid
+                                               unsigned int win_dtau, // tau grid step
+                                               unsigned int N_tauRange, // number of tau values (row length of Fmn)
+                                               float *Fmn // output, written at m*N_tauRange+n
+                                             )
+{
+
+  /* match CUDA thread indexing and high-level (t0,tau) indexing */
+  // assume 1D block, grid setup
+  unsigned int m         = blockDim.x * blockIdx.x + threadIdx.x; // t0:  row
+
+  unsigned short input_cols = 7; // must match input matrix!
+
+  /* compute Fstat-atom index i_t0 in [0, numAtoms) */
+  unsigned int TAtomHalf = TAtom/2; // integer division
+  unsigned int t0 = win_t0 + m * win_dt0;
+  /* integer round: floor(x+0.5) */
+  int i_tmp = ( t0 - t0_data + TAtomHalf ) / TAtom; // NOTE(review): all operands are unsigned, so t0 < t0_data wraps instead of going negative -- confirm callers guarantee t0 >= t0_data
+  if ( i_tmp < 0 ) {
+    i_tmp = 0;
+  }
+  unsigned int i_t0 = (unsigned int)i_tmp;
+  if ( i_t0 >= numAtoms ) {
+    i_t0 = numAtoms - 1;
+  }
+
+  float Ad    = 0.0f;
+  float Bd    = 0.0f;
+  float Cd    = 0.0f;
+  float Fa_re = 0.0f;
+  float Fa_im = 0.0f;
+  float Fb_re = 0.0f;
+  float Fb_im = 0.0f;
+  unsigned int i_t1_last = i_t0; // first atom index not yet added to the running sums
+
+  /* INNER loop over timescale-parameter tau
+   * NOT parallelized so that we can still use the i_t1_last trick
+   * (empirically seems to be faster than 2D CUDA version)
+   */
+  for ( unsigned int n = 0; n < N_tauRange; n ++ ) {
+
+    if ( (m < N_tauRange) && (n < N_tauRange) ) { // NOTE(review): m is a t0 index but is bounded by N_tauRange -- confirm the map is square or whether a t0 count is needed
+
+      /* translate n into an atoms end-index
+       * for this search interval [t0, t0+Tcoh],
+       * giving the index range of atoms to sum over
+       */
+      unsigned int tau = win_tau + n * win_dtau;
+
+      /* get end-time t1 of this transient-window search */
+      unsigned int t1 = t0 + tau;
+
+      /* compute window end-time Fstat-atom index i_t1 in [0, numAtoms)
+       * using integer round: floor(x+0.5)
+       */
+      i_tmp = ( t1 - t0_data + TAtomHalf ) / TAtom  - 1; // same unsigned-arithmetic caveat as for i_t0
+      if ( i_tmp < 0 ) {
+        i_tmp = 0;
+      }
+      unsigned int i_t1 = (unsigned int)i_tmp;
+      if ( i_t1 >= numAtoms ) {
+        i_t1 = numAtoms - 1;
+      }
+
+      /* now we have two valid atoms-indices [i_t0, i_t1]
+       * spanning our Fstat-window to sum over
+       */
+
+      for ( unsigned int i = i_t1_last; i <= i_t1; i ++ ) { // only atoms not summed yet
+        /* sum up atoms,
+         * special optimization in the rectangular-window case:
+         * just add on to previous tau values,
+         * ie re-use the sum over [i_t0, i_t1_last]
+         * from the previous tau-loop iteration
+         */
+        Ad    += input[i*input_cols+0]; // a2_alpha
+        Bd    += input[i*input_cols+1]; // b2_alpha
+        Cd    += input[i*input_cols+2]; // ab_alpha
+        Fa_re += input[i*input_cols+3]; // Fa_alpha_re
+        Fa_im += input[i*input_cols+4]; // Fa_alpha_im
+        Fb_re += input[i*input_cols+5]; // Fb_alpha_re
+        Fb_im += input[i*input_cols+6]; // Fb_alpha_im
+        /* keep track of up to where we summed for the next iteration */
+        i_t1_last = i_t1 + 1;
+      }
+
+      /* get determinant */
+      float Dd = ( Ad * Bd - Cd * Cd );
+      float DdInv = 0.0f;
+      /* safety catch as in XLALWeightMultiAMCoeffs():
+       * make it so that in the end F=0 instead of -nan
+       */
+      if ( Dd > 0.0 ) {
+        DdInv  = 1.0 / Dd;
+      }
+
+      /* from XLALComputeFstatFromFaFb */
+      float F  = DdInv * (  Bd * ( Fa_re*Fa_re + Fa_im*Fa_im )
+                          + Ad * ( Fb_re*Fb_re + Fb_im*Fb_im )
+                          - 2.0 * Cd * ( Fa_re * Fb_re + Fa_im * Fb_im )
+                         );
+
+      /* store result in Fstat-matrix
+       * at unraveled index of element {m,n}
+       */
+      unsigned int outidx = m * N_tauRange + n;
+      Fmn[outidx] = F;
+
+    } // if ( (m < N_tauRange) && (n < N_tauRange) )
+
+  } // for ( unsigned int n = 0; n < N_tauRange; n ++ )
+
+} // cudaTransientFstatRectWindow()
diff --git a/pyfstat/tcw_fstat_map_funcs.py b/pyfstat/tcw_fstat_map_funcs.py
new file mode 100644
index 0000000000000000000000000000000000000000..4237fa9cfd71031df214b1dd328e9565ca174a74
--- /dev/null
+++ b/pyfstat/tcw_fstat_map_funcs.py
@@ -0,0 +1,477 @@
+""" Additional helper functions dealing with transient-CW F(t0,tau) maps """
+
+import numpy as np
+import os
+import sys
+import logging
+
+# optional imports
+import importlib as imp
+
+
+def _optional_import ( modulename, shorthand=None ):
+    '''
+    Import a module/submodule only if it's available.
+
+    using importlib instead of __import__
+    because the latter doesn't handle sub.modules
+
+    Binds the module to globals()[shorthand] and returns True on success;
+    returns False if no such module exists, re-raises other ImportErrors.
+    '''
+
+    if shorthand is None:
+        shorthand    = modulename  # bind under the module's own name
+        shorthandbit = ''
+    else:
+        shorthandbit = ' as '+shorthand  # only used in the debug message
+
+    try:
+        globals()[shorthand] = imp.import_module(modulename)
+        logging.debug('Successfully imported module %s%s.'
+                      % (modulename, shorthandbit))
+        success = True
+    except ImportError, e:
+        if e.message == 'No module named '+modulename:  # exact message match only
+            logging.debug('No module {:s} found.'.format(modulename))
+            success = False
+        else:
+            raise  # e.g. a different module failed to import
+
+    return success
+
+
+class pyTransientFstatMap(object):
+    '''
+    simplified object class for a F(t0,tau) F-stat map (not 2F!)
+    based on LALSuite's transientFstatMap_t type
+    replacing the gsl matrix with a numpy array
+
+    F_mn:   2D array of F values (not 2F!)
+    maxF:   maximum of F (not 2F!)
+    t0_ML:  maximum likelihood transient start time t0 estimate
+    tau_ML: maximum likelihood transient duration tau estimate
+    '''
+
+    def __init__(self, N_t0Range, N_tauRange):
+        self.F_mn   = np.zeros((N_t0Range, N_tauRange), dtype=np.float32)
+        # Initializing maxF to a negative value ensures
+        # that we always update at least once and hence return
+        # sane t0_ML, tau_ML
+        # even if there is only a single bin where F=0 happens.
+        self.maxF   = float(-1.0)
+        self.t0_ML  = float(0.0)
+        self.tau_ML = float(0.0)
+
+
+# dictionary of the actual callable F-stat map functions we support,
+# if the corresponding modules are available.
+fstatmap_versions = {
+                     'lal':    lambda multiFstatAtoms, windowRange:  # lalpulsar is only bound once _optional_import('lalpulsar') has run
+                               getattr(lalpulsar,'ComputeTransientFstatMap')
+                                ( multiFstatAtoms, windowRange, False ),
+                     'pycuda': lambda multiFstatAtoms, windowRange:  # function name is looked up when the lambda is called
+                               pycuda_compute_transient_fstat_map
+                                ( multiFstatAtoms, windowRange )
+                    }
+
+
+def init_transient_fstat_map_features ( wantCuda=False, cudaDeviceName=None ):
+    '''
+    Initialization of available modules (or "features") for F-stat maps.
+
+    Returns (features, gpu_context): features is a dictionary keyed like
+    fstatmap_versions, each value True only if all required modules are
+    importable; gpu_context is a pyCUDA context, or None if CUDA is not used.
+    '''
+
+    features = {}
+
+    have_lal           = _optional_import('lal')
+    have_lalpulsar     = _optional_import('lalpulsar')
+    features['lal']    = have_lal and have_lalpulsar
+
+    # import GPU features
+    have_pycuda          = _optional_import('pycuda')
+    have_pycuda_drv      = _optional_import('pycuda.driver', 'drv')
+    have_pycuda_gpuarray = _optional_import('pycuda.gpuarray', 'gpuarray')
+    have_pycuda_tools    = _optional_import('pycuda.tools', 'cudatools')
+    have_pycuda_compiler = _optional_import('pycuda.compiler', 'cudacomp')
+    features['pycuda']   = ( have_pycuda_drv and have_pycuda_gpuarray and  # note: have_pycuda itself is not part of this test
+                            have_pycuda_tools and have_pycuda_compiler )
+
+    logging.debug('Got the following features for transient F-stat maps:')
+    logging.debug(features)
+
+    if wantCuda and features['pycuda']:  # otherwise the else branch at the end sets gpu_context = None
+        logging.debug('CUDA version: '+'.'.join(map(str,drv.get_version())))
+
+        drv.init()
+        logging.debug('Starting with default pyCUDA context,' \
+                      ' then checking all available devices...')
+        try:
+            context0 = pycuda.tools.make_default_context()
+        except pycuda._driver.LogicError, e:
+            if e.message == 'cuDeviceGet failed: invalid device ordinal':  # translated into a friendlier RuntimeError
+                devn = int(os.environ['CUDA_DEVICE'])
+                raise RuntimeError('Requested CUDA device number {} exceeds' \
+                                   ' number of available devices!' \
+                                   ' Please change through environment' \
+                                   ' variable $CUDA_DEVICE.'.format(devn))
+            else:
+                raise pycuda._driver.LogicError(e.message)  # any other LogicError is re-raised with the same message
+
+        num_gpus = drv.Device.count()
+        logging.debug('Found {} CUDA device(s).'.format(num_gpus))
+
+        devices = []
+        devnames = np.empty(num_gpus,dtype='S32')  # fixed-width 32-byte strings
+        for n in range(num_gpus):
+            devn = drv.Device(n)
+            devices.append(devn)
+            devnames[n] = devn.name().replace(' ','-').replace('_','-')  # spaces and underscores become '-'
+            logging.debug('device {}: model: {}, RAM: {}MB'.format(
+                n, devnames[n], devn.total_memory()/(2.**20) ))
+
+        if 'CUDA_DEVICE' in os.environ:  # devnum0: device the default context is presumed to use
+            devnum0 = int(os.environ['CUDA_DEVICE'])
+        else:
+            devnum0 = 0
+
+        matchbit = ''
+        if cudaDeviceName:
+            # allow partial matches in device names
+            devmatches = [devidx for devidx, devname in enumerate(devnames)
+                          if cudaDeviceName in devname]
+            if len(devmatches) == 0:
+                context0.detach()  # release the default context before raising
+                raise RuntimeError('Requested CUDA device "{}" not found.' \
+                                   ' Available devices: [{}]'.format(
+                                      cudaDeviceName,','.join(devnames)))
+            else:
+                devnum = devmatches[0]
+                if len(devmatches) > 1:
+                    logging.warning('Found {} CUDA devices matching name "{}".' \
+                                    ' Choosing first one with index {}.'.format(
+                                        len(devmatches),cudaDeviceName,devnum))
+            os.environ['CUDA_DEVICE'] = str(devnum)  # presumably read by make_default_context() below -- TODO confirm
+            matchbit =  '(matched to user request "{}")'.format(cudaDeviceName)
+        elif 'CUDA_DEVICE' in os.environ:
+            devnum = int(os.environ['CUDA_DEVICE'])
+        else:
+            devnum = 0
+        devn = devices[devnum]
+        logging.info('Choosing CUDA device {},' \
+                     ' of {} devices present: {}{}...'.format(
+                         devnum, num_gpus, devn.name(), matchbit))
+        if devnum == devnum0:  # keep the default context created above
+            gpu_context = context0
+        else:
+            context0.pop()  # otherwise replace it with a newly created context
+            gpu_context = pycuda.tools.make_default_context()
+            gpu_context.push()
+
+        _print_GPU_memory_MB('Available')
+    else:
+        gpu_context = None
+
+    return features, gpu_context
+
+
+def call_compute_transient_fstat_map ( version,  # key into fstatmap_versions
+                                       features,  # dict of booleans, indexed by version
+                                       multiFstatAtoms=None,
+                                       windowRange=None ):
+    '''Choose which version of the ComputeTransientFstatMap function to call.'''
+
+    if version in fstatmap_versions:
+        if features[version]:  # truthy value means this version may be called
+            FstatMap = fstatmap_versions[version](multiFstatAtoms, windowRange)
+        else:
+            raise Exception('Required module(s) for transient F-stat map' \
+                            ' method "{}" not available!'.format(version))
+    else:
+        raise Exception('Transient F-stat map method "{}"' \
+                        ' not implemented!'.format(version))
+    return FstatMap
+
+
+def reshape_FstatAtomsVector ( atomsVector ):
+    '''
+    Make a dictionary of ndarrays out of a atoms "vector" structure.
+
+    The input is a "vector"-like structure with times as the higher hierarchical
+    level and a set of "atoms" quantities defined at each timestamp.
+    The output is a dictionary with an entry for each quantity,
+    which is a 1D ndarray over timestamps for that one quantity.
+    '''
+
+    numAtoms = atomsVector.length
+    atomsDict = {}
+    atom_fieldnames = ['timestamp', 'Fa_alpha', 'Fb_alpha',  # attributes read from every atom
+                       'a2_alpha', 'ab_alpha', 'b2_alpha']
+    atom_dtypes     = [np.uint32, complex, complex,  # same order as atom_fieldnames
+                       np.float32, np.float32, np.float32]
+    for f, field in enumerate(atom_fieldnames):
+        atomsDict[field] = np.ndarray(numAtoms,dtype=atom_dtypes[f])  # uninitialised, filled in the loop below
+
+    for n,atom in enumerate(atomsVector.data):
+        for field in atom_fieldnames:
+            atomsDict[field][n] = atom.__getattribute__(field)
+
+    atomsDict['Fa_alpha_re'] = np.float32(atomsDict['Fa_alpha'].real)  # extra float32 entries for real/imag parts
+    atomsDict['Fa_alpha_im'] = np.float32(atomsDict['Fa_alpha'].imag)
+    atomsDict['Fb_alpha_re'] = np.float32(atomsDict['Fb_alpha'].real)
+    atomsDict['Fb_alpha_im'] = np.float32(atomsDict['Fb_alpha'].imag)
+
+    return atomsDict
+
+
+def _get_absolute_kernel_path ( kernel ):  # returns <dir of this file>/pyCUDAkernels/<kernel>.cu
+    pyfstatdir = os.path.dirname(os.path.abspath(os.path.realpath(__file__)))  # realpath resolves symlinks first
+    kernelfile = kernel + '.cu'
+    return os.path.join(pyfstatdir,'pyCUDAkernels',kernelfile)
+
+
+def _print_GPU_memory_MB ( key ):  # debug-log two drv.mem_get_info() values in MB, prefixed by key
+    mem_used_MB  = drv.mem_get_info()[0]/(2.**20)  # NOTE(review): named "used" but logged as "free" below -- confirm which is meant
+    mem_total_MB = drv.mem_get_info()[1]/(2.**20)
+    logging.debug('{} GPU memory: {:.4f} / {:.4f} MB free'.format(
+                      key, mem_used_MB, mem_total_MB))
+
+
+def pycuda_compute_transient_fstat_map ( multiFstatAtoms, windowRange ):
+    '''
+    GPU version of the function to compute the transient-window
+    "F-statistic map" over start-time and timescale {t0, tau}.
+    Based on XLALComputeTransientFstatMap from LALSuite,
+    (C) 2009 Reinhard Prix, licensed under GPL
+
+    multiFstatAtoms: multi-atoms input, merged here into a single binned
+        atoms vector via lalpulsar.mergeMultiFstatAtomsBinned().
+    windowRange: transient window range with fields
+        type, t0, t0Band, dt0, tau, tauBand, dtau.
+        NOTE: for type TRANSIENT_NONE it is modified in place, becoming
+        a single rectangular window spanning all the data.
+
+    Returns a pyTransientFstatMap whose 2D matrix F_mn has
+    m = index over start-times t0 and n = index over timescales tau,
+    in steps of dt0 in [t0, t0+t0Band] and dtau in [tau, tau+tauBand],
+    together with its maximum maxF and the grid point (t0_ML, tau_ML)
+    at which that maximum is found.
+    Raises ValueError for an unknown or unsupported window type.
+    '''
+
+    # reject window types beyond the known enum range
+    lastType = lalpulsar.TRANSIENT_LAST - 1
+    if windowRange.type > lastType:
+        raise ValueError ( ( 'Unknown window-type ({}) passed as input.'
+                             ' Allowed are [0,{}].' ).format(
+                                 windowRange.type, lastType ) )
+
+    # internal dict for search/setup parameters, handed on to the GPU helpers
+    tCWparams = {}
+
+    # merge all multi-atoms into one atoms-vector with *unique* timestamps
+    TAtom     = multiFstatAtoms.data[0].TAtom
+    TAtomHalf = int(TAtom/2) # integer division
+    tCWparams['TAtom'] = TAtom
+    mergedAtoms = lalpulsar.mergeMultiFstatAtomsBinned ( multiFstatAtoms,
+                                                         TAtom )
+    numAtoms = mergedAtoms.length
+    tCWparams['numAtoms'] = numAtoms
+
+    # one matrix column per atoms quantity, for transfer to the GPU
+    atomArrays = reshape_FstatAtomsVector(mergedAtoms)
+    columnOrder = ['a2_alpha', 'b2_alpha', 'ab_alpha',
+                   'Fa_alpha_re', 'Fa_alpha_im', 'Fb_alpha_re', 'Fb_alpha_im']
+    atomsInputMatrix = np.column_stack (
+                           [atomArrays[name] for name in columnOrder] )
+
+    # actual data spans [t0_data, t0_data + numAtoms * TAtom]
+    # in steps of TAtom
+    tCWparams['t0_data'] = int(mergedAtoms.data[0].timestamp)
+    tCWparams['t1_data'] = int(mergedAtoms.data[numAtoms-1].timestamp
+                               + TAtom)
+
+    logging.debug(('Transient F-stat map:'
+                   ' t0_data={:d}, t1_data={:d}').format(
+                       tCWparams['t0_data'], tCWparams['t1_data']))
+    logging.debug(('Transient F-stat map:'
+                   ' numAtoms={:d}, TAtom={:d},'
+                   ' TAtomHalf={:d}').format(numAtoms, TAtom, TAtomHalf))
+
+    # special treatment of window_type = none
+    # ==> replace by rectangular window spanning all the data
+    # (this changes the caller's windowRange object in place)
+    if windowRange.type == lalpulsar.TRANSIENT_NONE:
+        windowRange.type = lalpulsar.TRANSIENT_RECTANGULAR
+        windowRange.t0, windowRange.t0Band = tCWparams['t0_data'], 0
+        windowRange.tau, windowRange.tauBand = numAtoms * TAtom, 0
+        # step sizes are irrelevant for zero-width bands
+        windowRange.dt0 = windowRange.dtau = TAtom
+
+    # NOTE: indices {i,j} enumerate *actual* atoms and their timestamps t_i,
+    # while the indices {m,n} enumerate the full grid of values
+    # in [t0_min, t0_max]x[Tcoh_min, Tcoh_max] in steps of deltaT.
+    # This allows us to deal with gaps in the data in a transparent way.
+    #
+    # NOTE2: we operate on the 'binned' atoms returned
+    # from XLALmergeMultiFstatAtomsBinned(),
+    # which means we can safely assume all atoms to be lined up
+    # perfectly on a 'deltaT' binned grid.
+    #
+    # The mapping used will therefore be {i,j} -> {m,n}:
+    #   m = offs_i  / deltaT
+    #       (start-time offset from t0_min measured in deltaT)
+    #   n = Tcoh_ij / deltaT
+    #       (duration Tcoh_ij measured in deltaT)
+    # where
+    #   offs_i  = t_i - t0_min
+    #   Tcoh_ij = t_j - t_i + deltaT
+
+    # We allocate a matrix  {m x n} = t0Range * TcohRange elements
+    # covering the full transient window-range [t0,t0+t0Band]x[tau,tau+tauBand]
+    N_t0Range  = int(np.floor( 1.0*windowRange.t0Band
+                               / windowRange.dt0 ) + 1)
+    N_tauRange = int(np.floor( 1.0*windowRange.tauBand
+                               / windowRange.dtau ) + 1)
+    tCWparams['N_t0Range']  = N_t0Range
+    tCWparams['N_tauRange'] = N_tauRange
+    FstatMap = pyTransientFstatMap ( N_t0Range, N_tauRange )
+
+    logging.debug(('Transient F-stat map:'
+                   ' N_t0Range={:d}, N_tauRange={:d},'
+                   ' total grid points: {:d}').format(
+                       N_t0Range, N_tauRange, N_t0Range*N_tauRange))
+
+    # pick the GPU implementation matching the window type
+    if windowRange.type == lalpulsar.TRANSIENT_RECTANGULAR:
+        computeMap = pycuda_compute_transient_fstat_map_rect
+    elif windowRange.type == lalpulsar.TRANSIENT_EXPONENTIAL:
+        computeMap = pycuda_compute_transient_fstat_map_exp
+    else:
+        raise ValueError(('Invalid transient window type {}'
+                          ' not in [{}, {}].').format(
+                             windowRange.type, lalpulsar.TRANSIENT_NONE,
+                             lastType))
+    FstatMap.F_mn = computeMap ( atomsInputMatrix, windowRange, tCWparams )
+
+    # out of loop: get max2F and ML estimates over the m x n matrix
+    FstatMap.maxF = FstatMap.F_mn.max()
+    m_ML, n_ML = np.unravel_index ( FstatMap.F_mn.argmax(),
+                                    (N_t0Range, N_tauRange) )
+    FstatMap.t0_ML  = windowRange.t0  + m_ML * windowRange.dt0
+    FstatMap.tau_ML = windowRange.tau + n_ML * windowRange.dtau
+
+    logging.debug(('Done computing transient F-stat map.'
+                   ' maxF={:.4f}, t0_ML={}, tau_ML={}').format(
+                       FstatMap.maxF , FstatMap.t0_ML, FstatMap.tau_ML))
+
+    return FstatMap
+
+
+def pycuda_compute_transient_fstat_map_rect ( atomsInputMatrix,
+                                              windowRange,
+                                              tCWparams ):
+    '''
+    Rectangular window: compute the F-stat map on the GPU,
+    only GPU-parallelizing the outer loop (launch grid over N_t0Range only),
+    keeping partial sums with memory in kernel.
+
+    atomsInputMatrix: atoms data matrix (host array), copied to the GPU.
+    windowRange: provides t0, dt0, tau, dtau, passed on to the kernel.
+    tCWparams: dict with numAtoms, TAtom, t0_data, N_t0Range, N_tauRange.
+    Returns the float32 array F_mn of shape (N_t0Range, N_tauRange).
+    '''
+
+    # gpu data setup and transfer
+    _print_GPU_memory_MB('Initial')
+    input_gpu = gpuarray.to_gpu ( atomsInputMatrix )
+    Fmn_gpu   = gpuarray.GPUArray ( (tCWparams['N_t0Range'],
+                                     tCWparams['N_tauRange']),
+                                    dtype=np.float32 )
+    _print_GPU_memory_MB('After input+output allocation:')
+
+    # GPU kernel; the with-block makes sure the source file gets closed
+    kernel = 'cudaTransientFstatRectWindow'
+    with open(_get_absolute_kernel_path ( kernel ), 'r') as kernelsource:
+        partial_Fstat_cuda_code = cudacomp.SourceModule(kernelsource.read())
+    partial_Fstat_cuda = partial_Fstat_cuda_code.get_function(kernel)
+    partial_Fstat_cuda.prepare('PIIIIIIIIP')
+
+    # GPU grid setup
+    blockRows = min(1024,tCWparams['N_t0Range'])
+    blockCols = 1
+    gridRows  = int(np.ceil(1.0*tCWparams['N_t0Range']/blockRows))
+    gridCols  = 1
+
+    # running the kernel
+    logging.debug('Calling pyCUDA kernel with a grid of {}*{}={} blocks' \
+                  ' of {}*{}={} threads each: {} total threads...'.format(
+                      gridRows, gridCols, gridRows*gridCols,
+                      blockRows, blockCols, blockRows*blockCols,
+                      gridRows*gridCols*blockRows*blockCols))
+    partial_Fstat_cuda.prepared_call (
+        (gridRows,gridCols), (blockRows,blockCols,1), input_gpu.gpudata,
+        tCWparams['numAtoms'], tCWparams['TAtom'], tCWparams['t0_data'],
+        windowRange.t0, windowRange.dt0, windowRange.tau, windowRange.dtau,
+        tCWparams['N_tauRange'], Fmn_gpu.gpudata )
+
+    # return results to host
+    F_mn = Fmn_gpu.get()
+    _print_GPU_memory_MB('Final')
+
+    return F_mn
+
+
+def pycuda_compute_transient_fstat_map_exp ( atomsInputMatrix,
+                                             windowRange,
+                                             tCWparams ):
+    '''
+    Exponential window: compute the F-stat map on the GPU,
+    with inner and outer loop GPU-parallelized (launch grid over t0 and tau).
+
+    atomsInputMatrix: atoms data matrix (host array), copied to the GPU.
+    windowRange, tCWparams: supply the scalar kernel arguments listed below.
+    Returns the float32 array F_mn of shape (N_t0Range, N_tauRange).
+    '''
+
+    # gpu data setup and transfer
+    _print_GPU_memory_MB('Initial')
+    input_gpu = gpuarray.to_gpu ( atomsInputMatrix )
+    Fmn_gpu   = gpuarray.GPUArray ( (tCWparams['N_t0Range'],
+                                     tCWparams['N_tauRange']),
+                                    dtype=np.float32 )
+    _print_GPU_memory_MB('After input+output allocation:')
+
+    # GPU kernel; the with-block makes sure the source file gets closed
+    kernel = 'cudaTransientFstatExpWindow'
+    with open(_get_absolute_kernel_path ( kernel ), 'r') as kernelsource:
+        partial_Fstat_cuda_code = cudacomp.SourceModule(kernelsource.read())
+    partial_Fstat_cuda = partial_Fstat_cuda_code.get_function(kernel)
+    partial_Fstat_cuda.prepare('PIIIIIIIIIP')
+
+    # GPU grid setup
+    blockRows = min(32,tCWparams['N_t0Range'])
+    blockCols = min(32,tCWparams['N_tauRange'])
+    gridRows  = int(np.ceil(1.0*tCWparams['N_t0Range']/blockRows))
+    gridCols  = int(np.ceil(1.0*tCWparams['N_tauRange']/blockCols))
+
+    # running the kernel
+    logging.debug('Calling kernel with a grid of {}*{}={} blocks' \
+                  ' of {}*{}={} threads each: {} total threads...'.format(
+                      gridRows, gridCols, gridRows*gridCols,
+                      blockRows, blockCols, blockRows*blockCols,
+                      gridRows*gridCols*blockRows*blockCols))
+    partial_Fstat_cuda.prepared_call (
+        (gridRows,gridCols), (blockRows,blockCols,1), input_gpu.gpudata,
+        tCWparams['numAtoms'], tCWparams['TAtom'], tCWparams['t0_data'],
+        windowRange.t0, windowRange.dt0, windowRange.tau, windowRange.dtau,
+        tCWparams['N_t0Range'], tCWparams['N_tauRange'], Fmn_gpu.gpudata )
+
+    # return results to host
+    F_mn = Fmn_gpu.get()
+    _print_GPU_memory_MB('Final')
+
+    return F_mn
diff --git a/setup.py b/setup.py
index a9d69f304d33b723a3894c2f4ab88f2adfc2f737..2f03d835110d223f056e4798043888d9b241d60d 100644
--- a/setup.py
+++ b/setup.py
@@ -7,4 +7,7 @@ setup(name='PyFstat',
       author='Gregory Ashton',
       author_email='gregory.ashton@ligo.org',
       packages=['pyfstat'],
+      include_package_data=True,
+      package_data={'pyfstat': ['pyCUDAkernels/cudaTransientFstatExpWindow.cu',
+                                'pyCUDAkernels/cudaTransientFstatRectWindow.cu']},
       )