From 856dd6c232bb8abc01b95533253b97657d5c7dcd Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Wed, 21 Sep 2016 22:49:39 +0200
Subject: [PATCH] Adds binary search parameters to basic search and MCMC search

Note: this also adds a check for transient, if not called, then the
whole data is searched. This does not however mean only transient, but
also glitch like searches which use the transient tools.

Additionally fixes the plot_walkers command to allow it to plot walkers
when ndim =1 and to switch off the use of offsets
---
 pyfstat.py | 62 +++++++++++++++++++++++++++++++++++++++++-------------
 1 file changed, 47 insertions(+), 15 deletions(-)

diff --git a/pyfstat.py b/pyfstat.py
index 59ea288..22c919f 100755
--- a/pyfstat.py
+++ b/pyfstat.py
@@ -150,7 +150,8 @@ class ComputeFstat(object):
     @initializer
     def __init__(self, tref, sftlabel=None, sftdir=None,
                  minCoverFreq=None, maxCoverFreq=None,
-                 detector=None, earth_ephem=None, sun_ephem=None):
+                 detector=None, earth_ephem=None, sun_ephem=None,
+                 binary=False, transient=True):
 
         if earth_ephem is None:
             self.earth_ephem = self.earth_ephem_default
@@ -178,7 +179,11 @@ class ComputeFstat(object):
 
         logging.info('Initialising FstatInput')
         dFreq = 0
-        self.whatToCompute = lalpulsar.FSTATQ_ATOMS_PER_DET
+        if self.transient:
+            self.whatToCompute = lalpulsar.FSTATQ_ATOMS_PER_DET
+        else:
+            self.whatToCompute = lalpulsar.FSTATQ_2F
+
         FstatOptionalArgs = lalpulsar.FstatOptionalArgsDefaults
 
         if self.minCoverFreq is None or self.maxCoverFreq is None:
@@ -210,20 +215,29 @@ class ComputeFstat(object):
         logging.info('Initialising FstatResults')
         self.FstatResults = lalpulsar.FstatResults()
 
-        self.windowRange = lalpulsar.transientWindowRange_t()
-        self.windowRange.type = lalpulsar.TRANSIENT_RECTANGULAR
-        self.windowRange.t0Band = 0
-        self.windowRange.dt0 = 1
-        self.windowRange.tauBand = 0
-        self.windowRange.dtau = 1
+        if self.transient:
+            self.windowRange = lalpulsar.transientWindowRange_t()
+            self.windowRange.type = lalpulsar.TRANSIENT_RECTANGULAR
+            self.windowRange.t0Band = 0
+            self.windowRange.dt0 = 1
+            self.windowRange.tauBand = 0
+            self.windowRange.dtau = 1
 
     def run_computefstatistic_single_point(self, tstart, tend, F0, F1,
-                                           F2, Alpha, Delta):
+                                           F2, Alpha, Delta, asini=None,
+                                           period=None, ecc=None, tp=None,
+                                           argp=None):
         """ Compute the F-stat fully-coherently at a single point """
 
         self.PulsarDopplerParams.fkdot = np.array([F0, F1, F2, 0, 0, 0, 0])
         self.PulsarDopplerParams.Alpha = Alpha
         self.PulsarDopplerParams.Delta = Delta
+        if self.binary:
+            self.PulsarDopplerParams.asini = asini
+            self.PulsarDopplerParams.period = period
+            self.PulsarDopplerParams.ecc = ecc
+            self.PulsarDopplerParams.tp = tp
+            self.PulsarDopplerParams.argp = argp
 
         lalpulsar.ComputeFstat(self.FstatResults,
                                self.FstatInput,
@@ -232,6 +246,9 @@ class ComputeFstat(object):
                                self.whatToCompute
                                )
 
+        if self.transient is False:
+            return self.FstatResults.twoF[0]
+
         self.windowRange.t0 = int(tstart)  # TYPE UINT4
         self.windowRange.tau = int(tend - tstart)  # TYPE UINT4
         FS = lalpulsar.ComputeTransientFstatMap(
@@ -351,8 +368,8 @@ class MCMCSearch(BaseSearchClass):
     @initializer
     def __init__(self, label, outdir, sftlabel, sftdir, theta_prior, tref,
                  tstart, tend, nsteps=[100, 100, 100], nwalkers=100, ntemps=1,
-                 theta_initial=None, minCoverFreq=None,
-                 maxCoverFreq=None, scatter_val=1e-4, betas=None,
+                 theta_initial=None, minCoverFreq=None, maxCoverFreq=None,
+                 scatter_val=1e-4, betas=None, binary=False,
                  detector=None, earth_ephem=None, sun_ephem=None):
         """
         Parameters
@@ -385,6 +402,8 @@ class MCMCSearch(BaseSearchClass):
             Paths of the two files containing positions of Earth and Sun,
             respectively at evenly spaced times, as passed to CreateFstatInput
             If None defaults defined in BaseSearchClass will be used
+        binary: Bool
+            If true, search over binary parameters
 
         """
 
@@ -415,7 +434,7 @@ class MCMCSearch(BaseSearchClass):
             tref=self.tref, sftlabel=self.sftlabel,
             sftdir=self.sftdir, minCoverFreq=self.minCoverFreq,
             maxCoverFreq=self.maxCoverFreq, earth_ephem=self.earth_ephem,
-            sun_ephem=self.sun_ephem, detector=self.detector)
+            sun_ephem=self.sun_ephem, detector=self.detector, transient=False)
 
     def logp(self, theta_vals, theta_prior, theta_keys, search):
         H = [self.generic_lnprior(**theta_prior[key])(p) for p, key in
@@ -431,10 +450,17 @@ class MCMCSearch(BaseSearchClass):
     def unpack_input_theta(self):
         full_theta_keys = ['tstart', 'tend', 'F0', 'F1', 'F2', 'Alpha',
                            'Delta']
+        if self.binary:
+            full_theta_keys += [
+                'asini', 'period', 'ecc', 'tp', 'argp']
         full_theta_keys_copy = copy.copy(full_theta_keys)
 
         full_theta_symbols = ['_', '_', '$f$', '$\dot{f}$', '$\ddot{f}$',
                               r'$\alpha$', r'$\delta$']
+        if self.binary:
+            full_theta_symbols += [
+                'asini', 'period', 'period', 'ecc', 'tp', 'argp']
+
         self.theta_keys = []
         fixed_theta_dict = {}
         for key, val in self.theta_prior.iteritems():
@@ -685,12 +711,18 @@ class MCMCSearch(BaseSearchClass):
 
             if ndim > 1:
                 for i in range(ndim):
-                    axes[i].plot(chain[:, start:stop, i].T, color="k",
-                                 alpha=alpha)
+                    cs = chain[:, start:stop, i].T
+                    axes[i].plot(cs, color="k", alpha=alpha)
                     if symbols:
                         axes[i].set_ylabel(symbols[i])
                     if draw_vline is not None:
                         axes[i].axvline(draw_vline, lw=2, ls="--")
+                    axes[i].ticklabel_format(useOffset=False, axis='y')
+
+            else:
+                cs = chain[:, start:stop, 0].T
+                axes.plot(cs, color='k', alpha=alpha)
+                axes.ticklabel_format(useOffset=False, axis='y')
 
         return fig, axes
 
@@ -744,7 +776,7 @@ class MCMCSearch(BaseSearchClass):
             logging.warning("The sampler has produced nan's")
 
         p = pF[np.nanargmax(lnp)]
-        p0 = self._generate_scattered_p0(p)
+        p0 = self.generate_scattered_p0(p)
 
         return p0
 
-- 
GitLab