From e20d4a9ae86ddae1c78ae4d398e46e2bbb3d7757 Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Fri, 18 Nov 2016 12:42:51 +0100
Subject: [PATCH] Improvements to the template estimation (V, Vsky, Vpe)

1) Move the actual estimation to an external function, since it does not
depend on using the MCMC methods etc.
2) Change notation to V, Vsky and Vpe
3) Fix some table outputs (Tcoh was being forced to an integer)
4) Rename names -> detector_names
5) Add tests
6) Write a simpler table if the sky is not searched over
---
 pyfstat.py | 223 +++++++++++++++++++++++++++++++++--------------------
 tests.py   |  36 +++++++++
 2 files changed, 174 insertions(+), 85 deletions(-)

diff --git a/pyfstat.py b/pyfstat.py
index 45d57b5..d32caac 100755
--- a/pyfstat.py
+++ b/pyfstat.py
@@ -125,6 +125,71 @@ def read_par(label, outdir):
     return d
 
 
+def get_V_estimate(
+        nsegs, tref, minStartTime, maxStartTime, DeltaOmega, DeltaFs,
+        fiducial_freq, detector_names, earth_ephem, sun_ephem):
+    """ Returns V, Vsky, Vpe estimated from the super-sky metric
+
+    Parameters
+    ----------
+    nsegs: int
+        Number of semi-coherent segments
+    tref: int
+        Reference time in GPS seconds
+    minStartTime, maxStartTime: int
+        Minimum and maximum SFT timestamps
+    DeltaOmega: float
+        Solid angle of the sky-patch
+    DeltaFs: array
+        Array of [DeltaF0, DeltaF1, ...], length determines the number of
+        spin-down terms.
+    fiducial_freq: float
+        Fiducial frequency
+    detector_names: array
+        Array of detectors to average over
+    earth_ephem, sun_ephem: str
+        Paths to the ephemeris files
+
+    """
+    spindowns = len(DeltaFs) - 1
+    tboundaries = np.linspace(minStartTime, maxStartTime, nsegs+1)
+
+    ref_time = lal.LIGOTimeGPS(tref)
+    segments = lal.SegListCreate()
+    for j in range(len(tboundaries)-1):
+        seg = lal.SegCreate(lal.LIGOTimeGPS(tboundaries[j]),
+                            lal.LIGOTimeGPS(tboundaries[j+1]),
+                            j)
+        lal.SegListAppend(segments, seg)
+    detNames = lal.CreateStringVector(*detector_names)
+    detectors = lalpulsar.MultiLALDetector()
+    lalpulsar.ParseMultiLALDetector(detectors, detNames)
+    detector_weights = None
+    detector_motion = (lalpulsar.DETMOTION_SPIN
+                       + lalpulsar.DETMOTION_ORBIT)
+    ephemeris = lalpulsar.InitBarycenter(earth_ephem, sun_ephem)
+    try:
+        SSkyMetric = lalpulsar.ComputeSuperskyMetrics(
+            spindowns, ref_time, segments, fiducial_freq, detectors,
+            detector_weights, detector_motion, ephemeris)
+    except RuntimeError as e:
+        logging.debug('Encountered run-time error {}'.format(e))
+        return None, None, None
+
+    sqrtdetG_SKY = np.sqrt(np.linalg.det(
+        SSkyMetric.semi_rssky_metric.data[:2, :2]))
+    sqrtdetG_PE = np.sqrt(np.linalg.det(
+        SSkyMetric.semi_rssky_metric.data[2:, 2:]))
+
+    Vsky = .5*sqrtdetG_SKY*DeltaOmega
+    Vpe = sqrtdetG_PE * np.prod(DeltaFs)
+    if Vsky == 0:
+        Vsky = 1
+    if Vpe == 0:
+        Vpe = 1
+    return (Vsky * Vpe, Vsky, Vpe)
+
+
 class BaseSearchClass(object):
     """ The base search class, provides general functions """
 
@@ -291,8 +356,8 @@ class ComputeFstat(object):
         logging.info('Loading data matching pattern {}'.format(
                      self.sftfilepath))
         SFTCatalog = lalpulsar.SFTdataFind(self.sftfilepath, constraints)
-        names = list(set([d.header.name for d in SFTCatalog.data]))
-        self.names = names
+        detector_names = list(set([d.header.name for d in SFTCatalog.data]))
+        self.detector_names = detector_names
         SFT_timestamps = [d.header.epoch for d in SFTCatalog.data]
         try:
             from bashplotlib.histogram import plot_hist
@@ -300,10 +365,10 @@ class ComputeFstat(object):
             plot_hist(SFT_timestamps, height=5, bincount=50)
         except IOError:
             pass
-        if len(names) == 0:
+        if len(detector_names) == 0:
             raise ValueError('No data loaded.')
         logging.info('Loaded {} data files from detectors {}'.format(
-            len(SFT_timestamps), names))
+            len(SFT_timestamps), detector_names))
         logging.info('Data spans from {} ({}) to {} ({})'.format(
             int(SFT_timestamps[0]),
             subprocess.check_output('lalapps_tconvert {}'.format(
@@ -2009,74 +2074,39 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
         if prior[key]['type'] == 'unif':
             return .5*(prior[key]['upper'] + prior[key]['lower'])
 
-    def get_number_of_templates_estimate(self, nsegs):
-        """ Returns V, Vsky, Vf estimated from the super-sky metric """
-        if args.no_template_counting:
-            return 'N/A'
-        tboundaries = np.linspace(self.minStartTime, self.maxStartTime,
-                                  nsegs+1)
-
+    def init_V_estimate_parameters(self):
         if 'Alpha' in self.theta_keys:
             DeltaAlpha = self.get_width_from_prior(self.theta_prior, 'Alpha')
             DeltaDelta = self.get_width_from_prior(self.theta_prior, 'Delta')
             DeltaMid = self.get_mid_from_prior(self.theta_prior, 'Delta')
             DeltaOmega = np.sin(DeltaMid) * DeltaDelta * DeltaAlpha
+        else:
+            DeltaOmega = 0
         if 'F0' in self.theta_keys:
             DeltaF0 = self.get_width_from_prior(self.theta_prior, 'F0')
         else:
-            DeltaF0 = 1
+            raise ValueError('You are searching over F0?')
+        DeltaFs = [DeltaF0]
         if 'F1' in self.theta_keys:
             DeltaF1 = self.get_width_from_prior(self.theta_prior, 'F1')
-            spindowns = 1
-        else:
-            DeltaF1 = 1
-            spindowns = 0
-
-        ref_time = lal.LIGOTimeGPS(self.tref)
-        segments = lal.SegListCreate()
-        for j in range(len(tboundaries)-1):
-            seg = lal.SegCreate(lal.LIGOTimeGPS(tboundaries[j]),
-                                lal.LIGOTimeGPS(tboundaries[j+1]),
-                                j)
-            lal.SegListAppend(segments, seg)
+            DeltaFs.append(DeltaF1)
+        if 'F2' in self.theta_keys:
+            DeltaF2 = self.get_width_from_prior(self.theta_prior, 'F2')
+            DeltaFs.append(DeltaF2)
+
         if type(self.theta_prior['F0']) == dict:
             fiducial_freq = self.get_mid_from_prior(self.theta_prior, 'F0')
         else:
             fiducial_freq = self.theta_prior['F0']
-        detector_names = self.search.names
-        detNames = lal.CreateStringVector(*detector_names)
-        detectors = lalpulsar.MultiLALDetector()
-        lalpulsar.ParseMultiLALDetector(detectors, detNames)
-        detector_weights = None
-        detector_motion = (lalpulsar.DETMOTION_SPIN
-                           + lalpulsar.DETMOTION_ORBIT)
-        ephemeris = lalpulsar.InitBarycenter(self.earth_ephem,
-                                             self.sun_ephem)
-        try:
-            SSkyMetric = lalpulsar.ComputeSuperskyMetrics(
-                spindowns, ref_time, segments, fiducial_freq, detectors,
-                detector_weights, detector_motion, ephemeris)
-        except RuntimeError as e:
-            print e
-            return 'N/A'
-
-        sqrtdetG_SKY = np.sqrt(np.linalg.det(
-            SSkyMetric.semi_rssky_metric.data[:2, :2]))
-        sqrtdetG_F = np.sqrt(np.linalg.det(
-            SSkyMetric.semi_rssky_metric.data[2:, 2:]))
 
-        if 'Alpha' in self.theta_keys:
-            return (.5*sqrtdetG_SKY*sqrtdetG_F*DeltaOmega*DeltaF1*DeltaF0,
-                    .5*sqrtdetG_SKY*DeltaOmega,
-                    sqrtdetG_F*DeltaF1*DeltaF0)
-        else:
-            return (sqrtdetG_F*DeltaF1*DeltaF0,
-                    'N/A',
-                    sqrtdetG_F*DeltaF1*DeltaF0)
+        return fiducial_freq, DeltaOmega, DeltaFs
 
     def init_run_setup(self, run_setup, log_table=True, gen_tex_table=True):
         logging.info('Calculating the number of templates for this setup..')
-        number_of_templates = []
+        Vs = []
+        Vskys = []
+        Vpes = []
+        fiducial_freq, DeltaOmega, DeltaFs = self.init_V_estimate_parameters()
         for i, rs in enumerate(run_setup):
             rs = list(rs)
             if len(rs) == 2:
@@ -2084,51 +2114,74 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
             if np.shape(rs[0]) == ():
                 rs[0] = (rs[0], 0)
             run_setup[i] = rs
-            number_of_templates.append(
-                    self.get_number_of_templates_estimate(rs[1]))
+
+            V, Vsky, Vpe = get_V_estimate(
+                rs[1], self.tref, self.minStartTime, self.maxStartTime,
+                DeltaOmega, DeltaFs, fiducial_freq, self.search.detector_names,
+                self.earth_ephem, self.sun_ephem)
+            Vs.append(V)
+            Vskys.append(Vsky)
+            Vpes.append(Vpe)
 
         if log_table:
             logging.info('Using run-setup as follow:')
-            logging.info('Stage | nburn | nprod | nsegs | Tcoh | resetp0 |'
+            logging.info('Stage | nburn | nprod | nsegs | Tcoh d | resetp0 |'
                          '# templates = # sky x # Freq')
             for i, rs in enumerate(run_setup):
                 Tcoh = (self.maxStartTime - self.minStartTime) / rs[1] / 86400
-                if number_of_templates[i] == 'N/A':
-                    vtext = number_of_templates[i]
-                elif 'N/A' in number_of_templates[i]:
-                    vtext = '{:1.0e} = {} x {:1.0e}'.format(
-                            *number_of_templates[i])
+                if Vs[i] is None:
+                    vtext = 'N/A'
                 else:
                     vtext = '{:1.0e} = {:1.0e} x {:1.0e}'.format(
-                            *number_of_templates[i])
+                            Vs[i], Vskys[i], Vpes[i])
                 logging.info('{} | {} | {} | {} | {} | {} | {}'.format(
                     str(i).ljust(5), str(rs[0][0]).ljust(5),
                     str(rs[0][1]).ljust(5), str(rs[1]).ljust(5),
-                    '{:1.2f}'.format(Tcoh).ljust(4), str(rs[2]).ljust(7), vtext))
+                    '{:6.1f}'.format(Tcoh), str(rs[2]).ljust(7),
+                    vtext))
 
         if gen_tex_table:
             filename = '{}/{}_run_setup.tex'.format(self.outdir, self.label)
-            with open(filename, 'w+') as f:
-                f.write(r'\begin{tabular}{c|cccccc}' + '\n')
-                f.write(r'Stage & $\Nseg$ & $\Tcoh^{\rm days}$ &'
-                        r'$\Nsteps$ & $\V$ & $\Vsky$ & $\Vf$ \\ \hline'
-                        '\n')
-                for i, rs in enumerate(run_setup):
-                    Tcoh = (self.maxStartTime - self.minStartTime) / rs[1] / 86400
-                    line = r'{} & {} & {} & {} & {} & {} & {} \\' + '\n'
-                    if number_of_templates[i] == 'N/A':
-                        V = Vsky = Vf = 'N/A'
-                    else:
-                        V, Vsky, Vf = number_of_templates[i]
-                    if rs[0][-1] == 0:
-                        nsteps = rs[0][0]
-                    else:
-                        nsteps = '{},{}'.format(*rs[0])
-                    line = line.format(i, rs[1], Tcoh, nsteps,
-                                       texify_float(V), texify_float(Vsky),
-                                       texify_float(Vf))
-                    f.write(line)
-                f.write(r'\end{tabular}' + '\n')
+            if DeltaOmega > 0:
+                with open(filename, 'w+') as f:
+                    f.write(r'\begin{tabular}{c|cccccc}' + '\n')
+                    f.write(r'Stage & $\Nseg$ & $\Tcoh^{\rm days}$ &'
+                            r'$\Nsteps$ & $\V$ & $\Vsky$ & $\Vpe$ \\ \hline'
+                            '\n')
+                    for i, rs in enumerate(run_setup):
+                        Tcoh = float(self.maxStartTime - self.minStartTime)/rs[1]/86400
+                        line = r'{} & {} & {} & {} & {} & {} & {} \\' + '\n'
+                        if Vs[i] is None:
+                            Vs[i] = Vskys[i] = Vpes[i] = 'N/A'
+                        if rs[0][-1] == 0:
+                            nsteps = rs[0][0]
+                        else:
+                            nsteps = '{},{}'.format(*rs[0])
+                        line = line.format(i, rs[1], Tcoh, nsteps,
+                                           texify_float(Vs[i]),
+                                           texify_float(Vskys[i]),
+                                           texify_float(Vpes[i]))
+                        f.write(line)
+                    f.write(r'\end{tabular}' + '\n')
+            else:
+                with open(filename, 'w+') as f:
+                    f.write(r'\begin{tabular}{c|cccc}' + '\n')
+                    f.write(r'Stage & $\Nseg$ & $\Tcoh^{\rm days}$ &'
+                            r'$\Nsteps$ & $\Vpe$ \\ \hline'
+                            '\n')
+                    for i, rs in enumerate(run_setup):
+                        Tcoh = float(self.maxStartTime - self.minStartTime)/rs[1]/86400
+                        line = r'{} & {} & {} & {} & {} \\' + '\n'
+                        if Vs[i] is None:
+                            Vs[i] = Vskys[i] = Vpes[i] = 'N/A'
+                        if rs[0][-1] == 0:
+                            nsteps = rs[0][0]
+                        else:
+                            nsteps = '{},{}'.format(*rs[0])
+                        line = line.format(i, rs[1], Tcoh, nsteps,
+                                           texify_float(Vpes[i]))
+                        f.write(line)
+                    f.write(r'\end{tabular}' + '\n')
 
         if args.setup_only:
             logging.info("Exit as requested by setup_only flag")
diff --git a/tests.py b/tests.py
index e87c324..c0beb61 100644
--- a/tests.py
+++ b/tests.py
@@ -194,6 +194,42 @@ class TestMCMCSearch(Test):
             FS > predicted_FS or np.abs((FS-predicted_FS))/predicted_FS < 0.3)
 
 
+class TestAuxillaryFunctions(Test):
+    nsegs = 10
+    minStartTime = 1e9
+    maxStartTime = minStartTime + 100 * 86400
+    tref = .5*(minStartTime + maxStartTime)
+    DeltaOmega = 1e-2
+    DeltaFs = [1e-4, 1e-14]
+    fiducial_freq = 100
+    detector_names = ['H1', 'L1']
+    earth_ephem = pyfstat.earth_ephem
+    sun_ephem = pyfstat.sun_ephem
+
+    def test_get_V_estimate_sky_F0_F1(self):
+
+        out = pyfstat.get_V_estimate(
+            self.nsegs, self.tref, self.minStartTime, self.maxStartTime,
+            self.DeltaOmega, self.DeltaFs, self.fiducial_freq,
+            self.detector_names, self.earth_ephem, self.sun_ephem)
+        V, Vsky, Vpe = out
+        self.assertTrue(V == Vsky * Vpe)
+        self.__class__.Vpe_COMPUTED_WITH_SKY = Vpe
+
+    def test_get_V_estimate_F0_F1(self):
+        out = pyfstat.get_V_estimate(
+            self.nsegs, self.tref, self.minStartTime, self.maxStartTime,
+            self.DeltaOmega, self.DeltaFs, self.fiducial_freq,
+            self.detector_names, self.earth_ephem, self.sun_ephem)
+        V, Vsky, Vpe = out
+        self.assertTrue(V == Vsky * Vpe)
+        self.__class__.Vpe_COMPUTED_WITHOUT_SKY = Vpe
+
+    def test_the_equivalence_of_Vpe(self):
+        """Tests if the Vpe computed with and without the sky are equal """
+        self.assertEqual(self.__class__.Vpe_COMPUTED_WITHOUT_SKY,
+                         self.__class__.Vpe_COMPUTED_WITH_SKY)
+
 if __name__ == '__main__':
     outdir = 'TestData'
     if os.path.isdir(outdir):
-- 
GitLab