From be55d032f6a5d9b1043ad862fe21e02d94a4a71b Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Tue, 1 Nov 2016 13:29:42 +0100
Subject: [PATCH] Adds transient MCMC search

Also adds ability to write transient data and a search script to show
the results of such a transient search
---
 examples/make_fake_data.py              | 10 +++
 examples/transient_search_using_MCMC.py | 36 ++++++++++
 pyfstat.py                              | 89 +++++++++++++++++++++++--
 3 files changed, 131 insertions(+), 4 deletions(-)
 create mode 100644 examples/transient_search_using_MCMC.py

diff --git a/examples/make_fake_data.py b/examples/make_fake_data.py
index 49f450f..62c9d95 100644
--- a/examples/make_fake_data.py
+++ b/examples/make_fake_data.py
@@ -54,3 +54,13 @@ two_glitch_data = Writer(
     delta_F1=delta_F1, delta_F2=delta_F2)
 two_glitch_data.make_data()
 
+
+# Making transient data in the middle third
+data_tstart = tstart - duration
+data_duration = 3 * duration
+transient = Writer(
+    label='transient', outdir='data', tref=tref, tstart=tstart, F0=F0, F1=F1,
+    F2=F2, duration=duration, Alpha=Alpha, Delta=Delta, h0=h0, sqrtSX=sqrtSX,
+    data_tstart=data_tstart, data_duration=data_duration)
+transient.make_data()
+
diff --git a/examples/transient_search_using_MCMC.py b/examples/transient_search_using_MCMC.py
new file mode 100644
index 0000000..7fc7202
--- /dev/null
+++ b/examples/transient_search_using_MCMC.py
@@ -0,0 +1,36 @@
+from pyfstat import MCMCTransientSearch
+
+F0 = 30.0
+F1 = -1e-10
+F2 = 0
+Alpha = 5e-3
+Delta = 6e-2
+tref = 362750407.0
+
+tstart = 1000000000
+duration = 100*86400
+tstart = 1000000000 - duration
+tend = tstart + 3*duration
+
+theta_prior = {'F0': {'type': 'unif', 'lower': F0*(1-1e-6), 'upper': F0*(1+1e-6)},
+               'F1': {'type': 'unif', 'lower': F1*(1+1e-2), 'upper': F1*(1-1e-2)},
+               'F2': F2,
+               'Alpha': Alpha,
+               'Delta': Delta,
+               'transient_tstart': {'type': 'unif', 'lower': tstart, 'upper': tend},
+               'transient_duration': {'type': 'halfnorm', 'loc':0, 'scale': duration}
+               }
+
+ntemps = 4
+log10temperature_min = -1
+nwalkers = 100
+nsteps = [1000, 1000]
+
+mcmc = MCMCTransientSearch(
+    label='transient_search_using_MCMC', outdir='data',
+    sftfilepath='data/*transient*sft', theta_prior=theta_prior, tref=tref,
+    tstart=tstart, tend=tend, nsteps=nsteps, nwalkers=nwalkers, ntemps=ntemps,
+    log10temperature_min=log10temperature_min)
+mcmc.run()
+mcmc.plot_corner(add_prior=True)
+mcmc.print_summary()
diff --git a/pyfstat.py b/pyfstat.py
index a3982d3..fdb9016 100755
--- a/pyfstat.py
+++ b/pyfstat.py
@@ -1657,6 +1657,77 @@ _        sftfilepath: str
         fig.savefig('{}/{}_twoFcumulative.png'.format(self.outdir, self.label))
 
 
+class MCMCTransientSearch(MCMCSearch):
+    """ MCMC search for a transient signal using the ComputeFstat """
+
+    def inititate_search_object(self):
+        logging.info('Setting up search object')
+        self.search = ComputeFstat(
+            tref=self.tref, sftfilepath=self.sftfilepath,
+            minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq,
+            earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem,
+            detector=self.detector, transient=True,
+            minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
+            BSGL=self.BSGL)
+
+    def logl(self, theta, search):
+        for j, theta_i in enumerate(self.theta_idxs):
+            self.fixed_theta[theta_i] = theta[j]
+        if self.fixed_theta[1] < 86400:
+            return 0
+        self.fixed_theta[1] += self.fixed_theta[0]
+        if self.fixed_theta[1] > self.tend:
+            return 0
+        FS = search.run_computefstatistic_single_point(*self.fixed_theta)
+        return FS
+
+    def unpack_input_theta(self):
+        full_theta_keys = ['transient_tstart',
+                           'transient_duration', 'F0', 'F1', 'F2', 'Alpha',
+                           'Delta']
+        if self.binary:
+            full_theta_keys += [
+                'asini', 'period', 'ecc', 'tp', 'argp']
+        full_theta_keys_copy = copy.copy(full_theta_keys)
+
+        full_theta_symbols = [r'$t_{\rm start}$', r'$\Delta T$',
+                              '$f$', '$\dot{f}$', '$\ddot{f}$',
+                              r'$\alpha$', r'$\delta$']
+        if self.binary:
+            full_theta_symbols += [
+                'asini', 'period', 'period', 'ecc', 'tp', 'argp']
+
+        self.theta_keys = []
+        fixed_theta_dict = {}
+        self.theta_prior.pop('tstart')
+        self.theta_prior.pop('tend')
+        for key, val in self.theta_prior.iteritems():
+            if type(val) is dict:
+                fixed_theta_dict[key] = 0
+                self.theta_keys.append(key)
+            elif type(val) in [float, int, np.float64]:
+                fixed_theta_dict[key] = val
+            else:
+                raise ValueError(
+                    'Type {} of {} in theta not recognised'.format(
+                        type(val), key))
+            full_theta_keys_copy.pop(full_theta_keys_copy.index(key))
+
+        if len(full_theta_keys_copy) > 0:
+            raise ValueError(('Input dictionary `theta` is missing the'
+                              'following keys: {}').format(
+                                  full_theta_keys_copy))
+
+        self.fixed_theta = [fixed_theta_dict[key] for key in full_theta_keys]
+        self.theta_idxs = [full_theta_keys.index(k) for k in self.theta_keys]
+        self.theta_symbols = [full_theta_symbols[i] for i in self.theta_idxs]
+
+        idxs = np.argsort(self.theta_idxs)
+        self.theta_idxs = [self.theta_idxs[i] for i in idxs]
+        self.theta_symbols = [self.theta_symbols[i] for i in idxs]
+        self.theta_keys = [self.theta_keys[i] for i in idxs]
+
+
 class GridSearch(BaseSearchClass):
     """ Gridded search using ComputeFstat """
     @initializer
@@ -1943,7 +2014,8 @@ class Writer(BaseSearchClass):
                  delta_phi=0, delta_F0=0, delta_F1=0, delta_F2=0,
                  tref=None, phi=0, F0=30, F1=1e-10, F2=0, Alpha=5e-3,
                  Delta=6e-2, h0=0.1, cosi=0.0, psi=0.0, Tsft=1800, outdir=".",
-                 sqrtSX=1, Band=4, detector='H1'):
+                 sqrtSX=1, Band=4, detector='H1', data_tstart=None,
+                 data_duration=None):
         """
         Parameters
         ----------
@@ -1963,6 +2035,9 @@ class Writer(BaseSearchClass):
             pre-glitch phase, frequency, sky-position, and signal properties
         Tsft: float
             the sft duration
+        data_tstart, data_duration: float
+            if not None, the total span of data, this can be used to generate
+            transient signals
 
         see `lalapps_Makefakedata_v5 --help` for help with the other paramaters
         """
@@ -1974,7 +2049,7 @@ class Writer(BaseSearchClass):
         if self.dtglitch is None or self.dtglitch == self.duration:
             self.tbounds = [self.tstart, self.tend]
         elif np.size(self.dtglitch) == 1:
-           self.tbounds = [self.tstart, self.tstart+self.dtglitch, self.tend]
+            self.tbounds = [self.tstart, self.tstart+self.dtglitch, self.tend]
         else:
             self.tglitch = self.tstart + np.array(self.dtglitch)
             self.tbounds = [self.tstart] + list(self.tglitch) + [self.tend]
@@ -2129,8 +2204,14 @@ transientTauDays={:1.3f}\n""")
         cl.append('--outLabel="{}"'.format(self.label))
         cl.append('--IFOs="{}"'.format(self.detector))
         cl.append('--sqrtSX="{}"'.format(self.sqrtSX))
-        cl.append('--startTime={:10.9f}'.format(float(self.tstart)))
-        cl.append('--duration={}'.format(int(self.duration)))
+        if self.data_tstart is None:
+            cl.append('--startTime={:10.9f}'.format(float(self.tstart)))
+        else:
+            cl.append('--startTime={:10.9f}'.format(float(self.data_tstart)))
+        if self.data_duration is None:
+            cl.append('--duration={}'.format(int(self.duration)))
+        else:
+            cl.append('--duration={}'.format(int(self.data_duration)))
         cl.append('--fmin={}'.format(int(self.fmin)))
         cl.append('--Band={}'.format(self.Band))
         cl.append('--Tsft={}'.format(self.Tsft))
-- 
GitLab