From 3b9f8349ada82e1f2341fb9e1fd471cfa8e1fd4e Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Thu, 22 Sep 2016 15:12:11 +0200
Subject: [PATCH] Subclasses grid search

Gives a more general fully-coherent grid search along with the
specialised grid glitch search
---
 pyfstat.py | 109 +++++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 89 insertions(+), 20 deletions(-)

diff --git a/pyfstat.py b/pyfstat.py
index df9145d..21037ea 100755
--- a/pyfstat.py
+++ b/pyfstat.py
@@ -1079,14 +1079,14 @@ class MCMCGlitchSearch(MCMCSearch):
         return p0
 
 
-class GridGlitchSearch(BaseSearchClass):
-    """ Gridded search using the SemiCoherentGlitchSearch """
+class GridSearch(BaseSearchClass):
+    """ Gridded search using ComputeFstat """
     @initializer
     def __init__(self, label, outdir, sftlabel=None, sftdir=None, F0s=[0],
-                 F1s=[0], F2s=[0], delta_F0s=[0], delta_F1s=[0], tglitchs=None,
-                 Alphas=[0], Deltas=[0], tref=None, tstart=None, tend=None,
-                 minCoverFreq=None, maxCoverFreq=None, write_after=1000,
-                 earth_ephem=None, sun_ephem=None):
+                 F1s=[0], F2s=[0], Alphas=[0], Deltas=[0], tref=None,
+                 tstart=None, tend=None, minCoverFreq=None, maxCoverFreq=None,
+                 write_after=1000, earth_ephem=None, sun_ephem=None,
+                 detector=None):
         """
         Parameters
         label, outdir: str
@@ -1107,8 +1107,6 @@ class GridGlitchSearch(BaseSearchClass):
             If None defaults defined in BaseSearchClass will be used
 
         """
-        if tglitchs is None:
-            self.tglitchs = [self.tend]
         if sftlabel is None:
             self.sftlabel = self.label
         if sftdir is None:
@@ -1118,17 +1116,16 @@ class GridGlitchSearch(BaseSearchClass):
         if sun_ephem is None:
             self.sun_ephem = self.sun_ephem_default
 
-        self.search = SemiCoherentGlitchSearch(
-            label=label, outdir=outdir, sftlabel=sftlabel, sftdir=sftdir,
-            tref=tref, tstart=tstart, tend=tend, minCoverFreq=minCoverFreq,
-            maxCoverFreq=maxCoverFreq, earth_ephem=self.earth_ephem,
-            sun_ephem=self.sun_ephem)
+        self.search = ComputeFstat(
+            tref=self.tref, sftlabel=self.sftlabel,
+            sftdir=self.sftdir, minCoverFreq=self.minCoverFreq,
+            maxCoverFreq=self.maxCoverFreq, earth_ephem=self.earth_ephem,
+            sun_ephem=self.sun_ephem, detector=self.detector, transient=False)
 
         if os.path.isdir(outdir) is False:
             os.mkdir(outdir)
         self.out_file = '{}/{}_gridFS.txt'.format(self.outdir, self.label)
-        self.keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta', 'delta_F0',
-                     'delta_F1', 'tglitch']
+        self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta']
 
     def get_array_from_tuple(self, x):
         if len(x) == 1:
@@ -1138,8 +1135,8 @@ class GridGlitchSearch(BaseSearchClass):
 
     def get_input_data_array(self):
         arrays = []
-        for tup in (self.F0s, self.F1s, self.F2s, self.Alphas, self.Deltas,
-                    self.delta_F0s, self.delta_F1s, self.tglitchs):
+        for tup in ([self.tstart], [self.tend], self.F0s, self.F1s, self.F2s,
+                    self.Alphas, self.Deltas):
             arrays.append(self.get_array_from_tuple(tup))
 
         input_data = []
@@ -1176,7 +1173,7 @@ class GridGlitchSearch(BaseSearchClass):
         counter = 0
         data = []
         for vals in self.input_data:
-            FS = self.search.compute_glitch_fstat_single(*vals)
+            FS = self.search.run_computefstatistic_single_point(*vals)
             data.append(list(vals) + [FS])
 
             if counter > self.write_after:
@@ -1188,6 +1185,14 @@ class GridGlitchSearch(BaseSearchClass):
         np.savetxt(self.out_file, data, delimiter=' ')
         self.data = np.array(data)
 
+    def plot_1D(self, xkey):
+        fig, ax = plt.subplots()
+        xidx = self.keys.index(xkey)
+        x = np.unique(self.data[:, xidx])
+        z = self.data[:, -1]
+        plt.plot(x, z)
+        fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label))
+
     def plot_2D(self, xkey, ykey):
         fig, ax = plt.subplots()
         xidx = self.keys.index(xkey)
@@ -1214,6 +1219,71 @@ class GridGlitchSearch(BaseSearchClass):
         return np.max(twoF)
 
 
+class GridGlitchSearch(GridSearch):
+    """ Gridded search using the SemiCoherentGlitchSearch """
+    @initializer
+    def __init__(self, label, outdir, sftlabel=None, sftdir=None, F0s=[0],
+                 F1s=[0], F2s=[0], delta_F0s=[0], delta_F1s=[0], tglitchs=None,
+                 Alphas=[0], Deltas=[0], tref=None, tstart=None, tend=None,
+                 minCoverFreq=None, maxCoverFreq=None, write_after=1000,
+                 earth_ephem=None, sun_ephem=None):
+        """
+        Parameters
+        label, outdir: str
+            A label and directory to read/write data from/to
+        sftlabel, sftdir: str
+            A label and directory in which to find the relevant sft file
+        F0s, F1s, F2s, delta_F0s, delta_F1s, tglitchs, Alphas, Deltas: tuple
+            Length 3 tuple describing the grid for each parameter, e.g
+            [F0min, F0max, dF0], for a fixed value simply give [F0].
+        tref, tstart, tend: int
+            GPS seconds of the reference time, start time and end time
+        minCoverFreq, maxCoverFreq: float
+            Minimum and maximum instantaneous frequency which will be covered
+            over the SFT time span as passed to CreateFstatInput
+        earth_ephem, sun_ephem: str
+            Paths of the two files containing positions of Earth and Sun,
+            respectively at evenly spaced times, as passed to CreateFstatInput
+            If None defaults defined in BaseSearchClass will be used
+
+        """
+        if tglitchs is None:
+            self.tglitchs = [self.tend]
+        if sftlabel is None:
+            self.sftlabel = self.label
+        if sftdir is None:
+            self.sftdir = self.outdir
+        if earth_ephem is None:
+            self.earth_ephem = self.earth_ephem_default
+        if sun_ephem is None:
+            self.sun_ephem = self.sun_ephem_default
+
+        self.search = SemiCoherentGlitchSearch(
+            label=label, outdir=outdir, sftlabel=sftlabel, sftdir=sftdir,
+            tref=tref, tstart=tstart, tend=tend, minCoverFreq=minCoverFreq,
+            maxCoverFreq=maxCoverFreq, earth_ephem=self.earth_ephem,
+            sun_ephem=self.sun_ephem)
+
+        if os.path.isdir(outdir) is False:
+            os.mkdir(outdir)
+        self.out_file = '{}/{}_gridFS.txt'.format(self.outdir, self.label)
+        self.keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta', 'delta_F0',
+                     'delta_F1', 'tglitch']
+
+    def get_input_data_array(self):
+        arrays = []
+        for tup in (self.F0s, self.F1s, self.F2s, self.Alphas, self.Deltas,
+                    self.delta_F0s, self.delta_F1s, self.tglitchs):
+            arrays.append(self.get_array_from_tuple(tup))
+
+        input_data = []
+        for vals in itertools.product(*arrays):
+            input_data.append(vals)
+
+        self.arrays = arrays
+        self.input_data = np.array(input_data)
+
+
 class Writer(BaseSearchClass):
     """ Instance object for generating SFTs containing glitch signals """
     @initializer
@@ -1222,7 +1292,7 @@ class Writer(BaseSearchClass):
                  delta_phi=0, delta_F0=0, delta_F1=0, delta_F2=0,
                  tref=None, phi=0, F0=30, F1=1e-10, F2=0, Alpha=5e-3,
                  Delta=6e-2, h0=0.1, cosi=0.0, psi=0.0, Tsft=1800, outdir=".",
-                 sqrtSX=1, Band=4):
+                 sqrtSX=1, Band=4, detector='H1'):
         """
         Parameters
         ----------
@@ -1271,7 +1341,6 @@ class Writer(BaseSearchClass):
         self.delta_thetas = np.atleast_2d(
                 np.array([delta_phi, delta_F0, delta_F1, delta_F2]).T)
 
-        self.detector = 'H1'
         numSFTs = int(float(self.duration) / self.Tsft)
         self.sft_filename = lalpulsar.OfficialSFTFilename(
             'H', '1', numSFTs, self.Tsft, self.tstart, self.duration,
-- 
GitLab