From b0424e38b39e17246955817254798d1254bee1eb Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Fri, 22 Dec 2017 15:24:47 +0100
Subject: [PATCH] Create TransientGridSearch to provide transient grid searches

Previously, this functionality was part of GridSearch, this splits this
into a separate subclass to ease future development.
---
 .../short_transient_search_gridded.py         |   4 +-
 pyfstat/__init__.py                           |   2 +-
 pyfstat/grid_based_searches.py                | 134 ++++++++++++++----
 3 files changed, 107 insertions(+), 33 deletions(-)

diff --git a/examples/transient_examples/short_transient_search_gridded.py b/examples/transient_examples/short_transient_search_gridded.py
index 64e6423..8fc801f 100644
--- a/examples/transient_examples/short_transient_search_gridded.py
+++ b/examples/transient_examples/short_transient_search_gridded.py
@@ -30,7 +30,7 @@ Alphas = [Alpha]
 Deltas = [Delta]
 
 print('Standard CW search:')
-search1 = pyfstat.GridSearch(
+search1 = pyfstat.TransientGridSearch(
     label='CW', outdir=datadir,
     sftfilepattern=os.path.join(datadir,'*simulated_transient_signal*sft'),
     F0s=F0s, F1s=F1s, F2s=F2s, Alphas=Alphas, Deltas=Deltas, tref=tref,
@@ -44,7 +44,7 @@ search1.plot_1D(xkey='F0',
                xlabel='freq [Hz]', ylabel='$2\mathcal{F}$')
 
 print('with t0,tau bands:')
-search2 = pyfstat.GridSearch(
+search2 = pyfstat.TransientGridSearch(
     label='tCW', outdir=datadir,
     sftfilepattern=os.path.join(datadir,'*simulated_transient_signal*sft'),
     F0s=F0s, F1s=F1s, F2s=F2s, Alphas=Alphas, Deltas=Deltas, tref=tref,
diff --git a/pyfstat/__init__.py b/pyfstat/__init__.py
index 5c44708..3516c4a 100644
--- a/pyfstat/__init__.py
+++ b/pyfstat/__init__.py
@@ -3,4 +3,4 @@ from __future__ import division as _division
 from .core import BaseSearchClass, ComputeFstat, SemiCoherentSearch, SemiCoherentGlitchSearch
 from .make_sfts import Writer, GlitchWriter, FrequencyModulatedArtifactWriter, FrequencyAmplitudeModulatedArtifactWriter
 from .mcmc_based_searches import MCMCSearch, MCMCGlitchSearch, MCMCSemiCoherentSearch, MCMCFollowUpSearch, MCMCTransientSearch
-from .grid_based_searches import GridSearch, GridUniformPriorSearch, GridGlitchSearch, FrequencySlidingWindow, DMoff_NO_SPIN, SliceGridSearch
+from .grid_based_searches import GridSearch, GridUniformPriorSearch, GridGlitchSearch, FrequencySlidingWindow, DMoff_NO_SPIN, SliceGridSearch, TransientGridSearch
diff --git a/pyfstat/grid_based_searches.py b/pyfstat/grid_based_searches.py
index f9e2699..233c60c 100644
--- a/pyfstat/grid_based_searches.py
+++ b/pyfstat/grid_based_searches.py
@@ -34,9 +34,7 @@ class GridSearch(BaseSearchClass):
                  Deltas, tref=None, minStartTime=None, maxStartTime=None,
                  nsegs=1, BSGL=False, minCoverFreq=None, maxCoverFreq=None,
                  detectors=None, SSBprec=None, injectSources=None,
-                 input_arrays=False, assumeSqrtSX=None,
-                 transientWindowType=None, t0Band=None, tauBand=None,
-                 outputTransientFstatMap=False):
+                 input_arrays=False, assumeSqrtSX=None):
         """
         Parameters
         ----------
@@ -53,19 +51,6 @@ class GridSearch(BaseSearchClass):
             GPS seconds of the reference time, start time and end time
         input_arrays: bool
             if true, use the F0s, F1s, etc as is
-        transientWindowType: str
-            If 'rect' or 'exp', compute atoms so that a transient (t0,tau) map
-            can later be computed.  ('none' instead of None explicitly calls
-            the transient-window function, but with the full range, for
-            debugging). Currently only supported for nsegs=1.
-        t0Band, tauBand: int
-            if >0, search t0 in (minStartTime,minStartTime+t0Band)
-                   and tau in (2*Tsft,2*Tsft+tauBand).
-            if =0, only compute CW Fstat with t0=minStartTime,
-                   tau=maxStartTime-minStartTime.
-        outputTransientFstatMap: bool
-            if true, write output files for (t0,tau) Fstat maps
-            (one file for each doppler grid point!)
 
         For all other parameters, see `pyfstat.ComputeFStat` for details
         """
@@ -85,8 +70,6 @@ class GridSearch(BaseSearchClass):
                 tref=self.tref, sftfilepattern=self.sftfilepattern,
                 minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq,
                 detectors=self.detectors,
-                transientWindowType=self.transientWindowType,
-                t0Band=self.t0Band, tauBand=self.tauBand,
                 minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
                 BSGL=self.BSGL, SSBprec=self.SSBprec,
                 injectSources=self.injectSources,
@@ -170,19 +153,7 @@ class GridSearch(BaseSearchClass):
         data = []
         for vals in tqdm(self.input_data):
             detstat = self.search.get_det_stat(*vals)
-            windowRange = getattr(self.search, 'windowRange', None)
-            FstatMap = getattr(self.search, 'FstatMap', None)
             thisCand = list(vals) + [detstat]
-            if getattr(self, 'transientWindowType', None):
-                if self.outputTransientFstatMap:
-                    tCWfile = os.path.splitext(self.out_file)[0]+'_tCW_%.16f_%.16f_%.16f_%.16g_%.16g.dat' % (vals[2],vals[5],vals[6],vals[3],vals[4]) # freq alpha delta f1dot f2dot
-                    fo = lal.FileOpen(tCWfile, 'w')
-                    lalpulsar.write_transientFstatMap_to_fp ( fo, FstatMap, windowRange, None )
-                    del fo # instead of lal.FileClose() which is not SWIG-exported
-                Fmn = FstatMap.F_mn.data
-                maxidx = np.unravel_index(Fmn.argmax(), Fmn.shape)
-                thisCand += [windowRange.t0+maxidx[0]*windowRange.dt0,
-                             windowRange.tau+maxidx[1]*windowRange.dtau]
             data.append(thisCand)
 
         data = np.array(data, dtype=np.float)
@@ -372,6 +343,109 @@ class GridSearch(BaseSearchClass):
                 self.outdir, self.label, dets, type(self).__name__)
 
 
+class TransientGridSearch(GridSearch):
+    """ Gridded transient-continous search using ComputeFstat """
+
+    @helper_functions.initializer
+    def __init__(self, label, outdir, sftfilepattern, F0s, F1s, F2s, Alphas,
+                 Deltas, tref=None, minStartTime=None, maxStartTime=None,
+                 BSGL=False, minCoverFreq=None, maxCoverFreq=None,
+                 detectors=None, SSBprec=None, injectSources=None,
+                 input_arrays=False, assumeSqrtSX=None,
+                 transientWindowType=None, t0Band=None, tauBand=None,
+                 outputTransientFstatMap=False):
+        """
+        Parameters
+        ----------
+        label, outdir: str
+            A label and directory to read/write data from/to
+        sftfilepattern: str
+            Pattern to match SFTs using wildcards (*?) and ranges [0-9];
+            mutiple patterns can be given separated by colons.
+        F0s, F1s, F2s, delta_F0s, delta_F1s, tglitchs, Alphas, Deltas: tuple
+            Length 3 tuple describing the grid for each parameter, e.g
+            [F0min, F0max, dF0], for a fixed value simply give [F0]. Unless
+            input_arrays == True, then these are the values to search at.
+        tref, minStartTime, maxStartTime: int
+            GPS seconds of the reference time, start time and end time
+        input_arrays: bool
+            if true, use the F0s, F1s, etc as is
+        transientWindowType: str
+            If 'rect' or 'exp', compute atoms so that a transient (t0,tau) map
+            can later be computed.  ('none' instead of None explicitly calls
+            the transient-window function, but with the full range, for
+            debugging). Currently only supported for nsegs=1.
+        t0Band, tauBand: int
+            if >0, search t0 in (minStartTime,minStartTime+t0Band)
+                   and tau in (2*Tsft,2*Tsft+tauBand).
+            if =0, only compute CW Fstat with t0=minStartTime,
+                   tau=maxStartTime-minStartTime.
+        outputTransientFstatMap: bool
+            if true, write output files for (t0,tau) Fstat maps
+            (one file for each doppler grid point!)
+
+        For all other parameters, see `pyfstat.ComputeFStat` for details
+        """
+
+        self.nsegs = 1
+        if os.path.isdir(outdir) is False:
+            os.mkdir(outdir)
+        self.set_out_file()
+        self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta']
+        self.search_keys = [x+'s' for x in self.keys[2:]]
+        for k in self.search_keys:
+            setattr(self, k, np.atleast_1d(getattr(self, k)))
+
+    def inititate_search_object(self):
+        logging.info('Setting up search object')
+        self.search = ComputeFstat(
+            tref=self.tref, sftfilepattern=self.sftfilepattern,
+            minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq,
+            detectors=self.detectors,
+            transientWindowType=self.transientWindowType,
+            t0Band=self.t0Band, tauBand=self.tauBand,
+            minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
+            BSGL=self.BSGL, SSBprec=self.SSBprec,
+            injectSources=self.injectSources,
+            assumeSqrtSX=self.assumeSqrtSX)
+        self.search.get_det_stat = self.search.get_fullycoherent_twoF
+
+    def run(self, return_data=False):
+        self.get_input_data_array()
+        old_data = self.check_old_data_is_okay_to_use()
+        if old_data is not False:
+            self.data = old_data
+            return
+
+        if hasattr(self, 'search') is False:
+            self.inititate_search_object()
+
+        data = []
+        for vals in tqdm(self.input_data):
+            detstat = self.search.get_det_stat(*vals)
+            windowRange = getattr(self.search, 'windowRange', None)
+            FstatMap = getattr(self.search, 'FstatMap', None)
+            thisCand = list(vals) + [detstat]
+            if getattr(self, 'transientWindowType', None):
+                if self.outputTransientFstatMap:
+                    tCWfile = os.path.splitext(self.out_file)[0]+'_tCW_%.16f_%.16f_%.16f_%.16g_%.16g.dat' % (vals[2],vals[5],vals[6],vals[3],vals[4]) # freq alpha delta f1dot f2dot
+                    fo = lal.FileOpen(tCWfile, 'w')
+                    lalpulsar.write_transientFstatMap_to_fp ( fo, FstatMap, windowRange, None )
+                    del fo # instead of lal.FileClose() which is not SWIG-exported
+                Fmn = FstatMap.F_mn.data
+                maxidx = np.unravel_index(Fmn.argmax(), Fmn.shape)
+                thisCand += [windowRange.t0+maxidx[0]*windowRange.dt0,
+                             windowRange.tau+maxidx[1]*windowRange.dtau]
+            data.append(thisCand)
+
+        data = np.array(data, dtype=np.float)
+        if return_data:
+            return data
+        else:
+            self.save_array_to_disk(data)
+            self.data = data
+
+
 class SliceGridSearch(GridSearch):
     """ Slice gridded search using ComputeFstat """
     @helper_functions.initializer
-- 
GitLab