From d6edafae11fb3b0e65d98643c9ee9a36a4f51f6a Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Thu, 10 Nov 2016 18:03:59 +0100
Subject: [PATCH] Adds initial check of the typical number of segments

The motivating idea here is to give the user an idea of how well the
MCMC is likely to perform: if the typical number of templates is of
order millions or more then the MCMC will work like a random template
back, jumping back and forth without finding a peak (for a long time).
---
 pyfstat.py | 97 +++++++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 85 insertions(+), 12 deletions(-)

diff --git a/pyfstat.py b/pyfstat.py
index 511fef4..04515a0 100755
--- a/pyfstat.py
+++ b/pyfstat.py
@@ -53,11 +53,11 @@ parser.add_argument("-q", "--quite", help="Decrease output verbosity",
 parser.add_argument("-c", "--clean", help="Don't use cached data",
                     action="store_true")
 parser.add_argument("-u", "--use-old-data", action="store_true")
+parser.add_argument('-s', "--setup-only", action="store_true")
 parser.add_argument('unittest_args', nargs='*')
 args, unknown = parser.parse_known_args()
 sys.argv[1:] = args.unittest_args
 
-
 logger = logging.getLogger()
 logger.setLevel(logging.DEBUG)
 stream_handler = logging.StreamHandler()
@@ -268,6 +268,7 @@ class ComputeFstat(object):
                      self.sftfilepath))
         SFTCatalog = lalpulsar.SFTdataFind(self.sftfilepath, constraints)
         names = list(set([d.header.name for d in SFTCatalog.data]))
+        self.names = names
         SFT_timestamps = [d.header.epoch for d in SFTCatalog.data]
         try:
             from bashplotlib.histogram import plot_hist
@@ -285,7 +286,7 @@ class ComputeFstat(object):
                 int(SFT_timestamps[0])), shell=True).rstrip('\n'),
             int(SFT_timestamps[-1]),
             subprocess.check_output('lalapps_tconvert {}'.format(
-                int(SFT_timestamps[-1])), shell=True)).rstrip('\n'))
+                int(SFT_timestamps[-1])), shell=True).rstrip('\n')))
 
         logging.info('Initialising ephems')
         ephems = lalpulsar.InitBarycenter(self.earth_ephem, self.sun_ephem)
@@ -1581,8 +1582,8 @@ class MCMCSearch(BaseSearchClass):
             tglitches = [d['tglitch']]
         else:
             tglitches = [d['tglitch_{}'.format(i)] for i in range(self.nglitch)]
-        tbounderies = [self.tstart] + tglitches + [self.tend]
-        deltaTs = np.diff(tbounderies)
+        tboundaries = [self.tstart] + tglitches + [self.tend]
+        deltaTs = np.diff(tboundaries)
         ntrials = [time_trials + delta_F0 * dT for dT in deltaTs]
         p_val = self.p_val_twoFhat(max_twoF, ntrials)
         print('p-value = {}'.format(p_val))
@@ -1862,11 +1863,11 @@ _        sftfilepath: str
             delta_F0s[:self.theta0_idx] *= -1
             tglitches = [d['tglitch']]
 
-        tbounderies = [self.tstart] + tglitches + [self.tend]
+        tboundaries = [self.tstart] + tglitches + [self.tend]
 
         for j in range(self.nglitch+1):
-            ts = tbounderies[j]
-            te = tbounderies[j+1]
+            ts = tboundaries[j]
+            te = tboundaries[j+1]
             if (te - ts)/86400 < 5:
                 logging.info('Period too short to perform cumulative search')
                 continue
@@ -1960,6 +1961,70 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
                  BSGL=self.BSGL, run_setup=self.run_setup)
         return d
 
+    def get_width_from_prior(self, prior, key):
+        if prior[key]['type'] == 'unif':
+            return prior[key]['upper'] - prior[key]['lower']
+
+    def get_mid_from_prior(self, prior, key):
+        if prior[key]['type'] == 'unif':
+            return .5*(prior[key]['upper'] + prior[key]['lower'])
+
+    def get_number_of_templates_estimate(self, nsegs):
+        tboundaries = np.linspace(self.minStartTime, self.maxStartTime,
+                                  nsegs+1)
+        if 'Alpha' in self.theta_keys:
+            DeltaAlpha = self.get_width_from_prior(self.theta_prior, 'Alpha')
+            DeltaDelta = self.get_width_from_prior(self.theta_prior, 'Delta')
+            DeltaMid = self.get_mid_from_prior(self.theta_prior, 'Delta')
+            DeltaOmega = np.sin(DeltaMid) * DeltaDelta * DeltaAlpha
+            if 'F0' in self.theta_keys:
+                DeltaF0 = self.get_width_from_prior(self.theta_prior, 'F0')
+            else:
+                DeltaF0 = 1
+            if 'F1' in self.theta_keys:
+                DeltaF1 = self.get_width_from_prior(self.theta_prior, 'F1')
+                spindowns = 1
+            else:
+                DeltaF1 = 1
+                spindowns = 0
+
+            ref_time = lal.LIGOTimeGPS(self.tref)
+            segments = lal.SegListCreate()
+            for j in range(len(tboundaries)-1):
+                seg = lal.SegCreate(lal.LIGOTimeGPS(tboundaries[j]),
+                                    lal.LIGOTimeGPS(tboundaries[j+1]),
+                                    j)
+                lal.SegListAppend(segments, seg)
+            if type(self.theta_prior['F0']) == dict:
+                fiducial_freq = self.get_mid_from_prior(self.theta_prior, 'F0')
+            else:
+                fiducial_freq = self.theta_prior['F0']
+            detector_names = self.search.names
+            detNames = lal.CreateStringVector(*detector_names)
+            detectors = lalpulsar.MultiLALDetector()
+            lalpulsar.ParseMultiLALDetector(detectors, detNames)
+            detector_weights = None
+            detector_motion = (lalpulsar.DETMOTION_SPIN
+                               + lalpulsar.DETMOTION_ORBIT)
+            ephemeris = lalpulsar.InitBarycenter(self.earth_ephem,
+                                                 self.sun_ephem)
+            try:
+                SSkyMetric = lalpulsar.ComputeSuperskyMetrics(
+                    spindowns, ref_time, segments, fiducial_freq, detectors,
+                    detector_weights, detector_motion, ephemeris)
+                sqrtdetG_SKY = np.sqrt(np.linalg.det(
+                    SSkyMetric.semi_rssky_metric.data[:2, :2]))
+                sqrtdetG_F = np.sqrt(np.linalg.det(
+                    SSkyMetric.semi_rssky_metric.data[2:, 2:]))
+                return '{:1.0e} = 1/2 x {:1.0e} x {:1.0e}'.format(
+                        .5*sqrtdetG_SKY*sqrtdetG_F*DeltaOmega*DeltaF1*DeltaF0,
+                        sqrtdetG_SKY*DeltaOmega,
+                        sqrtdetG_F*DeltaF1*DeltaF0)
+            except RuntimeError:
+                return 'N/A'
+        elif self.theta_keys == ['F0', 'F1']:
+            return 'N/A'
+
     def run(self, run_setup, proposal_scale_factor=2):
         """ Run the follow-up with the given run_setup
 
@@ -1969,8 +2034,12 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
 
         """
 
+        self.nsegs = 1
+        self.inititate_search_object()
+
         logging.info('Using run-setup as follow:')
-        logging.info('Stage | nburn | nprod | nsegs | resetp0')
+        logging.info('Stage | nburn | nprod | nsegs | resetp0 |'
+                     '# templates = # sky x # Freq')
         for i, rs in enumerate(run_setup):
             rs = list(rs)
             if len(rs) == 2:
@@ -1978,9 +2047,15 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
             if np.shape(rs[0]) == ():
                 rs[0] = (rs[0], 0)
             run_setup[i] = rs
-            logging.info('{} | {} | {} | {} | {}'.format(
+            N = self.get_number_of_templates_estimate(rs[1])
+            logging.info('{} | {} | {} | {} | {} | {}'.format(
                 str(i).ljust(5), str(rs[0][0]).ljust(5),
-                str(rs[0][1]).ljust(5), str(rs[1]).ljust(5), rs[2]))
+                str(rs[0][1]).ljust(5), str(rs[1]).ljust(5),
+                str(rs[2]).ljust(7), N))
+
+        if args.setup_only:
+            logging.info("Exit as requested by setup_only flag")
+            sys.exit()
 
         self.run_setup = run_setup
         self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
@@ -1998,8 +2073,6 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
         fig = None
         axes = None
         nsteps_total = 0
-        self.nsegs = 1
-        self.inititate_search_object()
         for j, ((nburn, nprod), nseg, reset_p0) in enumerate(run_setup):
             if j == 0:
                 p0 = self.generate_initial_p0()
-- 
GitLab