From d75d67b5d2322c9d004419cea8e810cbc0ee2122 Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Fri, 28 Jul 2017 18:02:56 +0200
Subject: [PATCH] Initial (untested) changes to proper calculation of the
 volume

---
 pyfstat/helper_functions.py        |   6 ++
 pyfstat/mcmc_based_searches.py     |   4 +-
 pyfstat/optimal_setup_functions.py | 109 ++++++++++++++++++++---------
 3 files changed, 83 insertions(+), 36 deletions(-)

diff --git a/pyfstat/helper_functions.py b/pyfstat/helper_functions.py
index 25a5601..fbb3147 100644
--- a/pyfstat/helper_functions.py
+++ b/pyfstat/helper_functions.py
@@ -11,6 +11,7 @@ import inspect
 import peakutils
 from functools import wraps
 from scipy.stats.distributions import ncx2
+import lal
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -203,3 +204,8 @@ def run_commandline (cl):
     os.system('\n')
 
     return(out)
+
+def convert_array_to_gsl_matrix(array):
+    gsl_matrix =  lal.gsl_matrix(*array.shape)
+    gsl_matrix.data = array
+    return gsl_matrix
diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py
index 3b7a6d4..cf448bd 100644
--- a/pyfstat/mcmc_based_searches.py
+++ b/pyfstat/mcmc_based_searches.py
@@ -1908,7 +1908,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
             if generate_setup:
                 nsegs_vals, V_vals = get_optimal_setup(
                     R, Nsegs0, self.tref, self.minStartTime,
-                    self.maxStartTime, DeltaOmega, DeltaFs, fiducial_freq,
+                    self.maxStartTime, self.theta_prior, fiducial_freq,
                     self.search.detector_names, self.earth_ephem,
                     self.sun_ephem)
                 self.write_setup_input_file(run_setup_input_file, R, Nsegs0,
@@ -1936,7 +1936,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
                 else:
                     V = get_V_estimate(
                         rs[1], self.tref, self.minStartTime, self.maxStartTime,
-                        DeltaOmega, DeltaFs, fiducial_freq,
+                        self.theta_prior, fiducial_freq,
                         self.search.detector_names, self.earth_ephem,
                         self.sun_ephem)
                     V_vals.append(V)
diff --git a/pyfstat/optimal_setup_functions.py b/pyfstat/optimal_setup_functions.py
index 0108246..543fed8 100644
--- a/pyfstat/optimal_setup_functions.py
+++ b/pyfstat/optimal_setup_functions.py
@@ -10,17 +10,18 @@ import numpy as np
 import scipy.optimize
 import lal
 import lalpulsar
+import helper_functions
 
 
 def get_optimal_setup(
-        R, Nsegs0, tref, minStartTime, maxStartTime, DeltaOmega,
-        DeltaFs, fiducial_freq, detector_names, earth_ephem, sun_ephem):
+        R, Nsegs0, tref, minStartTime, maxStartTime, prior, fiducial_freq,
+        detector_names, earth_ephem, sun_ephem):
     logging.info('Calculating optimal setup for R={}, Nsegs0={}'.format(
         R, Nsegs0))
 
     V_0 = get_V_estimate(
-        Nsegs0, tref, minStartTime, maxStartTime, DeltaOmega, DeltaFs,
-        fiducial_freq, detector_names, earth_ephem, sun_ephem)
+        Nsegs0, tref, minStartTime, maxStartTime, prior, fiducial_freq,
+        detector_names, earth_ephem, sun_ephem)
     logging.info('Stage {}, nsegs={}, V={}'.format(0, Nsegs0, V_0))
 
     nsegs_vals = [Nsegs0]
@@ -30,8 +31,8 @@ def get_optimal_setup(
     nsegs_i = Nsegs0
     while nsegs_i > 1:
         nsegs_i, V_i = get_nsegs_ip1(
-            nsegs_i, R, tref, minStartTime, maxStartTime, DeltaOmega,
-            DeltaFs, fiducial_freq, detector_names, earth_ephem, sun_ephem)
+            nsegs_i, R, tref, minStartTime, maxStartTime, prior, fiducial_freq,
+            detector_names, earth_ephem, sun_ephem)
         nsegs_vals.append(nsegs_i)
         V_vals.append(V_i)
         i += 1
@@ -42,13 +43,13 @@ def get_optimal_setup(
 
 
 def get_nsegs_ip1(
-        nsegs_i, R, tref, minStartTime, maxStartTime, DeltaOmega,
-        DeltaFs, fiducial_freq, detector_names, earth_ephem, sun_ephem):
+        nsegs_i, R, tref, minStartTime, maxStartTime, prior, fiducial_freq,
+        detector_names, earth_ephem, sun_ephem):
 
     log10R = np.log10(R)
     log10Vi = np.log10(get_V_estimate(
-        nsegs_i, tref, minStartTime, maxStartTime, DeltaOmega, DeltaFs,
-        fiducial_freq, detector_names, earth_ephem, sun_ephem))
+        nsegs_i, tref, minStartTime, maxStartTime, prior, fiducial_freq,
+        detector_names, earth_ephem, sun_ephem))
 
     def f(nsegs_ip1):
         if nsegs_ip1[0] > nsegs_i:
@@ -59,8 +60,8 @@ def get_nsegs_ip1(
         if nsegs_ip1 == 0:
             nsegs_ip1 = 1
         Vip1 = get_V_estimate(
-            nsegs_ip1, tref, minStartTime, maxStartTime, DeltaOmega,
-            DeltaFs, fiducial_freq, detector_names, earth_ephem, sun_ephem)
+            nsegs_ip1, tref, minStartTime, maxStartTime, prior, fiducial_freq,
+            detector_names, earth_ephem, sun_ephem)
         if Vip1 is None:
             return 1e6
         else:
@@ -73,15 +74,47 @@ def get_nsegs_ip1(
         nsegs_ip1 = 1
     if res.success:
         return nsegs_ip1, get_V_estimate(
-            nsegs_ip1, tref, minStartTime, maxStartTime, DeltaOmega, DeltaFs,
-            fiducial_freq, detector_names, earth_ephem, sun_ephem)
+            nsegs_ip1, tref, minStartTime, maxStartTime, prior, fiducial_freq,
+            detector_names, earth_ephem, sun_ephem)
     else:
         raise ValueError('Optimisation unsuccesful')
 
 
+def get_parallelepiped(prior):
+    keys = ['Alpha', 'Delta', 'F0', 'F1', 'F2']
+    spindown_keys = keys[3:]
+    sky_keys = keys[:2]
+    lims = []
+    lims_keys = []
+    lims_idxs = []
+    for i, key in enumerate(keys):
+        if type(prior[key]) == dict:
+            if prior[key]['type'] == 'unif':
+                lims.append([prior[key]['lower'], prior[key]['upper']])
+                lims_keys.append(key)
+                lims_idxs.append(i)
+            else:
+                raise ValueError(
+                    "Prior type {} not yet supported".format(
+                        prior[key]['type']))
+        elif key not in spindown_keys:
+            lims.append([prior[key], 0])
+    lims = np.array(lims)
+    lims_keys = np.array(lims_keys)
+    base = lims[:, 0]
+    p = [base]
+    for i in lims_idxs:
+        basex = base.copy()
+        basex[i] = lims[i, 1]
+        p.append(basex)
+    spindowns = np.sum([np.sum(lims_keys == k) for k in spindown_keys])
+    sky = any([key in lims_keys for key in sky_keys])
+    return np.array(p).T, spindowns , sky
+
+
 def get_V_estimate(
-        nsegs, tref, minStartTime, maxStartTime, DeltaOmega, DeltaFs,
-        fiducial_freq, detector_names, earth_ephem, sun_ephem):
+        nsegs, tref, minStartTime, maxStartTime, prior, fiducial_freq,
+        detector_names, earth_ephem, sun_ephem):
     """ Returns V estimated from the super-sky metric
 
     Parameters
@@ -92,11 +125,8 @@ def get_V_estimate(
         Reference time in GPS seconds
     minStartTime, maxStartTime: int
         Minimum and maximum SFT timestamps
-    DeltaOmega: float
-        Solid angle of the sky-patch
-    DeltaFs: array
-        Array of [DeltaF0, DeltaF1, ...], length determines the number of
-        spin-down terms.
+    prior: dict
+        The prior dictionary
     fiducial_freq: float
         Fidicual frequency
     detector_names: array
@@ -105,7 +135,12 @@ def get_V_estimate(
         Paths to the ephemeris files
 
     """
-    spindowns = len(DeltaFs) - 1
+    in_phys, spindowns, sky = get_parallelepiped(prior)
+    out_rssky = np.zeros(in_phys.shape)
+
+    in_phys = helper_functions.convert_array_to_gsl_matrix(in_phys)
+    out_rssky = helper_functions.convert_array_to_gsl_matrix(out_rssky)
+
     tboundaries = np.linspace(minStartTime, maxStartTime, nsegs+1)
 
     ref_time = lal.LIGOTimeGPS(tref)
@@ -130,15 +165,21 @@ def get_V_estimate(
         logging.debug('Encountered run-time error {}'.format(e))
         return None, None, None
 
-    sqrtdetG_SKY = np.sqrt(np.linalg.det(
-        SSkyMetric.semi_rssky_metric.data[:2, :2]))
-    sqrtdetG_PE = np.sqrt(np.linalg.det(
-        SSkyMetric.semi_rssky_metric.data[2:, 2:]))
-
-    Vsky = .5*sqrtdetG_SKY*DeltaOmega
-    Vpe = sqrtdetG_PE * np.prod(DeltaFs)
-    if Vsky == 0:
-        Vsky = 1
-    if Vpe == 0:
-        Vpe = 1
-    return Vsky * Vpe
+    if sky:
+        i = 0
+    else:
+        i = 2
+
+    lalpulsar.ConvertPhysicalToSuperskyPoints(
+        out_rssky, in_phys, SSkyMetric.semi_rssky_transf)
+
+    parallelepiped = (out_rssky.data[i:, 1:].T - out_rssky.data[i:, 0]).T
+
+    sqrtdetG = np.sqrt(np.linalg.det(
+        SSkyMetric.semi_rssky_metric.data[i:, i:]))
+
+    dV = np.abs(np.linalg.det(parallelepiped))
+
+    V = sqrtdetG * dV
+
+    return V
-- 
GitLab