From 8435f54d19002576ee5174cff7b7310f9eaf3f98 Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Wed, 21 Sep 2016 18:40:25 +0200
Subject: [PATCH] Splits up the MCMC classes

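This adds a general MCMCSearch class, which uses ComputeFstat, and makes
MCMCGlitchSearch a subclass of it. Several methods are also renamed to
snake_case (e.g. PlotWalkers -> plot_walkers, GenerateRV -> generate_rv).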
---
 pyfstat.py     | 276 +++++++++++++++++++++++++++++++++++--------------
 tests/tests.py |  12 +--
 2 files changed, 205 insertions(+), 83 deletions(-)

diff --git a/pyfstat.py b/pyfstat.py
index adeed2b..59ea288 100755
--- a/pyfstat.py
+++ b/pyfstat.py
@@ -346,15 +346,14 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat):
         return twoFsegA + twoFsegB
 
 
-class MCMCGlitchSearch(BaseSearchClass):
-    """ MCMC search using the SemiCoherentGlitchSearch """
+class MCMCSearch(BaseSearchClass):
+    """ MCMC search using ComputeFstat"""
     @initializer
     def __init__(self, label, outdir, sftlabel, sftdir, theta_prior, tref,
                  tstart, tend, nsteps=[100, 100, 100], nwalkers=100, ntemps=1,
-                 nglitch=0, theta_initial=None, minCoverFreq=None,
+                 theta_initial=None, minCoverFreq=None,
                  maxCoverFreq=None, scatter_val=1e-4, betas=None,
-                 detector=None, dtglitchmin=20*86400, earth_ephem=None,
-                 sun_ephem=None):
+                 detector=None, earth_ephem=None, sun_ephem=None):
         """
         Parameters
         label, outdir: str
@@ -370,8 +369,6 @@ class MCMCGlitchSearch(BaseSearchClass):
             Either a dictionary of distribution about which to distribute the
             initial walkers about, an array (from which the walkers will be
             scattered by scatter_val, or  None in which case the prior is used.
-        nglitch: int
-            The number of glitches to allow
         tref, tstart, tend: int
             GPS seconds of the reference time, start time and end time
         nsteps: list (m,)
@@ -379,9 +376,6 @@ class MCMCGlitchSearch(BaseSearchClass):
             give the nburn and nprod of the 'production' run, all entries
             before are for iterative initialisation steps (usually just one)
             e.g. [1000, 1000, 500].
-        dtglitchmin: int
-            The minimum duration (in seconds) of a segment between two glitches
-            or a glitch and the start/end of the data
         nwalkers, ntemps: int
             Number of walkers and temperatures
         minCoverFreq, maxCoverFreq: float
@@ -394,12 +388,14 @@ class MCMCGlitchSearch(BaseSearchClass):
 
         """
 
-        logging.info(('Set-up MCMC search with {} glitches for model {} on'
-                      ' data {}').format(self.nglitch, self.label,
-                                         self.sftlabel))
+        logging.info(
+            'Set-up MCMC search for model {} on data {}'.format(
+                self.label, self.sftlabel))
         if os.path.isdir(outdir) is False:
             os.mkdir(outdir)
         self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
+        self.theta_prior['tstart'] = self.tstart
+        self.theta_prior['tend'] = self.tend
         self.unpack_input_theta()
         self.ndim = len(self.theta_keys)
         self.sft_filepath = self.sftdir+'/*_'+self.sftlabel+"*sft"
@@ -415,65 +411,43 @@ class MCMCGlitchSearch(BaseSearchClass):
 
     def inititate_search_object(self):
         logging.info('Setting up search object')
-        self.search = SemiCoherentGlitchSearch(
-            label=self.label, outdir=self.outdir, sftlabel=self.sftlabel,
-            sftdir=self.sftdir, tref=self.tref, tstart=self.tstart,
-            tend=self.tend, minCoverFreq=self.minCoverFreq,
+        self.search = ComputeFstat(
+            tref=self.tref, sftlabel=self.sftlabel,
+            sftdir=self.sftdir, minCoverFreq=self.minCoverFreq,
             maxCoverFreq=self.maxCoverFreq, earth_ephem=self.earth_ephem,
-            sun_ephem=self.sun_ephem, detector=self.detector,
-            nglitch=self.nglitch)
+            sun_ephem=self.sun_ephem, detector=self.detector)
 
     def logp(self, theta_vals, theta_prior, theta_keys, search):
-        if self.nglitch > 1:
-            ts = [self.tstart] + theta_vals[-self.nglitch:] + [self.tend]
-            if np.array_equal(ts, np.sort(ts)) is False:
-                return -np.inf
-            if any(np.diff(ts) < self.dtglitchmin):
-                return -np.inf
-
-        H = [self.Generic_lnprior(**theta_prior[key])(p) for p, key in
+        H = [self.generic_lnprior(**theta_prior[key])(p) for p, key in
              zip(theta_vals, theta_keys)]
         return np.sum(H)
 
     def logl(self, theta, search):
         for j, theta_i in enumerate(self.theta_idxs):
             self.fixed_theta[theta_i] = theta[j]
-        FS = search.compute_nglitch_fstat(*self.fixed_theta)
+        FS = search.run_computefstatistic_single_point(*self.fixed_theta)
         return FS
 
     def unpack_input_theta(self):
-        glitch_keys = ['delta_F0', 'delta_F1', 'tglitch']
-        full_glitch_keys = list(np.array(
-            [[gk]*self.nglitch for gk in glitch_keys]).flatten())
-        full_theta_keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta']+full_glitch_keys
+        full_theta_keys = ['tstart', 'tend', 'F0', 'F1', 'F2', 'Alpha',
+                           'Delta']
         full_theta_keys_copy = copy.copy(full_theta_keys)
 
-        glitch_symbols = ['$\delta f$', '$\delta \dot{f}$', r'$t_{glitch}$']
-        full_glitch_symbols = list(np.array(
-            [[gs]*self.nglitch for gs in glitch_symbols]).flatten())
-        full_theta_symbols = (['$f$', '$\dot{f}$', '$\ddot{f}$', r'$\alpha$',
-                               r'$\delta$'] + full_glitch_symbols)
+        full_theta_symbols = ['_', '_', '$f$', '$\dot{f}$', '$\ddot{f}$',
+                              r'$\alpha$', r'$\delta$']
         self.theta_keys = []
         fixed_theta_dict = {}
         for key, val in self.theta_prior.iteritems():
             if type(val) is dict:
                 fixed_theta_dict[key] = 0
-                if key in glitch_keys:
-                    for i in range(self.nglitch):
-                        self.theta_keys.append(key)
-                else:
-                    self.theta_keys.append(key)
+                self.theta_keys.append(key)
             elif type(val) in [float, int, np.float64]:
                 fixed_theta_dict[key] = val
             else:
                 raise ValueError(
                     'Type {} of {} in theta not recognised'.format(
                         type(val), key))
-            if key in glitch_keys:
-                for i in range(self.nglitch):
-                    full_theta_keys_copy.pop(full_theta_keys_copy.index(key))
-            else:
-                full_theta_keys_copy.pop(full_theta_keys_copy.index(key))
+            full_theta_keys_copy.pop(full_theta_keys_copy.index(key))
 
         if len(full_theta_keys_copy) > 0:
             raise ValueError(('Input dictionary `theta` is missing the'
@@ -489,13 +463,6 @@ class MCMCGlitchSearch(BaseSearchClass):
         self.theta_symbols = [self.theta_symbols[i] for i in idxs]
         self.theta_keys = [self.theta_keys[i] for i in idxs]
 
-        # Correct for number of glitches in the idxs
-        self.theta_idxs = np.array(self.theta_idxs)
-        while np.sum(self.theta_idxs[:-1] == self.theta_idxs[1:]) > 0:
-            for i, idx in enumerate(self.theta_idxs):
-                if idx in self.theta_idxs[:i]:
-                    self.theta_idxs[i] += 1
-
     def check_initial_points(self, p0):
         initial_priors = np.array([
             self.logp(p, self.theta_prior, self.theta_keys, self.search)
@@ -525,7 +492,8 @@ class MCMCGlitchSearch(BaseSearchClass):
             logpargs=(self.theta_prior, self.theta_keys, self.search),
             loglargs=(self.search,), betas=self.betas)
 
-        p0 = self.GenerateInitial()
+        p0 = self.generate_initial_p0()
+        p0 = self.apply_corrections_to_p0(p0)
         self.check_initial_points(p0)
 
         ninit_steps = len(self.nsteps) - 2
@@ -534,11 +502,12 @@ class MCMCGlitchSearch(BaseSearchClass):
                 j, ninit_steps, n))
             sampler.run_mcmc(p0, n)
 
-            fig, axes = self.PlotWalkers(sampler, symbols=self.theta_symbols)
+            fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols)
             fig.savefig('{}/{}_init_{}_walkers.png'.format(
                 self.outdir, self.label, j))
 
             p0 = self.get_new_p0(sampler, scatter_val=self.scatter_val)
+            p0 = self.apply_corrections_to_p0(p0)
             self.check_initial_points(p0)
             sampler.reset()
 
@@ -548,7 +517,7 @@ class MCMCGlitchSearch(BaseSearchClass):
             nburn+nprod))
         sampler.run_mcmc(p0, nburn+nprod)
 
-        fig, axes = self.PlotWalkers(sampler, symbols=self.theta_symbols)
+        fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols)
         fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label))
 
         samples = sampler.chain[0, :, nburn:, :].reshape((-1, self.ndim))
@@ -622,14 +591,14 @@ class MCMCGlitchSearch(BaseSearchClass):
             ax = axes[i][i]
             xlim = ax.get_xlim()
             s = samples[:, i]
-            prior = self.Generic_lnprior(**self.theta_prior[key])
+            prior = self.generic_lnprior(**self.theta_prior[key])
             x = np.linspace(s.min(), s.max(), 100)
             ax2 = ax.twinx()
             ax2.get_yaxis().set_visible(False)
             ax2.plot(x, [prior(xi) for xi in x], '-r')
             ax.set_xlim(xlim)
 
-    def Generic_lnprior(self, **kwargs):
+    def generic_lnprior(self, **kwargs):
         """ Return a lambda function of the pdf
 
         Parameters
@@ -679,7 +648,7 @@ class MCMCGlitchSearch(BaseSearchClass):
             logging.info("kwargs:", kwargs)
             raise ValueError("Print unrecognise distribution")
 
-    def GenerateRV(self, **kwargs):
+    def generate_rv(self, **kwargs):
         dist_type = kwargs.pop('type')
         if dist_type == "unif":
             return np.random.uniform(low=kwargs['lower'], high=kwargs['upper'])
@@ -694,8 +663,8 @@ class MCMCGlitchSearch(BaseSearchClass):
         else:
             raise ValueError("dist_type {} unknown".format(dist_type))
 
-    def PlotWalkers(self, sampler, symbols=None, alpha=0.4, color="k", temp=0,
-                    start=None, stop=None, draw_vline=None):
+    def plot_walkers(self, sampler, symbols=None, alpha=0.4, color="k", temp=0,
+                     start=None, stop=None, draw_vline=None):
         """ Plot all the chains from a sampler """
 
         shape = sampler.chain.shape
@@ -725,38 +694,35 @@ class MCMCGlitchSearch(BaseSearchClass):
 
         return fig, axes
 
-    def _generate_scattered_p0(self, p):
+    def apply_corrections_to_p0(self, p0):
+        """ Apply any correction to the initial p0 values """
+        return p0
+
+    def generate_scattered_p0(self, p):
         """ Generate a set of p0s scattered about p """
-        p0 = [[p + scatter_val * p * np.random.randn(self.ndim)
+        p0 = [[p + self.scatter_val * p * np.random.randn(self.ndim)
                for i in xrange(self.nwalkers)]
               for j in xrange(self.ntemps)]
         return p0
 
-    def _sort_p0_times(self, p0):
-        p0 = np.array(p0)
-        p0[:, :, -self.nglitch:] = np.sort(p0[:, :, -self.nglitch:], axis=2)
-        return p0
-
-    def GenerateInitial(self):
+    def generate_initial_p0(self):
         """ Generate a set of init vals for the walkers """
 
         if type(self.theta_initial) == dict:
-            p0 = [[[self.GenerateRV(**self.theta_initial[key])
+            p0 = [[[self.generate_rv(**self.theta_initial[key])
                     for key in self.theta_keys]
                    for i in range(self.nwalkers)]
                   for j in range(self.ntemps)]
         elif self.theta_initial is None:
-            p0 = [[[self.GenerateRV(**self.theta_prior[key])
+            p0 = [[[self.generate_rv(**self.theta_prior[key])
                     for key in self.theta_keys]
                    for i in range(self.nwalkers)]
                   for j in range(self.ntemps)]
         elif len(self.theta_initial) == self.ndim:
-            p0 = self._generate_scattered_p0(self.theta_initial)
+            p0 = self.generate_scattered_p0(self.theta_initial)
         else:
             raise ValueError('theta_initial not understood')
 
-        if self.nglitch > 1:
-            p0 = self._sort_p0_times(p0)
         return p0
 
     def get_new_p0(self, sampler, scatter_val=1e-3):
@@ -780,8 +746,6 @@ class MCMCGlitchSearch(BaseSearchClass):
         p = pF[np.nanargmax(lnp)]
-        p0 = self._generate_scattered_p0(p)
+        p0 = self.generate_scattered_p0(p)
 
-        if self.nglitch > 1:
-            p0 = self._sort_p0_times(p0)
         return p0
 
     def get_save_data_dictionary(self):
@@ -923,6 +887,164 @@ class MCMCGlitchSearch(BaseSearchClass):
                     k, d[k], d[k+'_std']))
 
 
+class MCMCGlitchSearch(MCMCSearch):
+    """ MCMC search using the SemiCoherentGlitchSearch """
+    @initializer
+    def __init__(self, label, outdir, sftlabel, sftdir, theta_prior, tref,
+                 tstart, tend, nsteps=[100, 100, 100], nwalkers=100, ntemps=1,
+                 nglitch=0, theta_initial=None, minCoverFreq=None,
+                 maxCoverFreq=None, scatter_val=1e-4, betas=None,
+                 detector=None, dtglitchmin=20*86400, earth_ephem=None,
+                 sun_ephem=None):
+        """
+        Parameters
+        label, outdir: str
+            A label and directory to read/write data from/to
+        sftlabel, sftdir: str
+            A label and directory in which to find the relevant sft file
+        theta_prior: dict
+            Dictionary of priors and fixed values for the search parameters.
+            For each parameter (key of the dict), if it is to be held fixed
+            the value should be the constant float, if it is to be searched,
+            the value should be a dictionary of the prior.
+        theta_initial: dict, array, (None)
+            Either a dictionary of distributions from which to draw the
+            initial walkers, an array (from which the walkers will be
+            scattered by scatter_val), or None in which case the prior is used.
+        nglitch: int
+            The number of glitches to allow
+        tref, tstart, tend: int
+            GPS seconds of the reference time, start time and end time
+        nsteps: list (m,)
+            List specifying the number of steps to take, the last two entries
+            give the nburn and nprod of the 'production' run, all entries
+            before are for iterative initialisation steps (usually just one)
+            e.g. [1000, 1000, 500].
+        dtglitchmin: int
+            The minimum duration (in seconds) of a segment between two glitches
+            or a glitch and the start/end of the data
+        nwalkers, ntemps: int
+            Number of walkers and temperatures
+        minCoverFreq, maxCoverFreq: float
+            Minimum and maximum instantaneous frequency which will be covered
+            over the SFT time span as passed to CreateFstatInput
+        earth_ephem, sun_ephem: str
+            Paths of the two files containing positions of Earth and Sun,
+            respectively at evenly spaced times, as passed to CreateFstatInput
+            If None defaults defined in BaseSearchClass will be used
+
+        """
+
+        logging.info(('Set-up MCMC glitch search with {} glitches for model {}'
+                      ' on data {}').format(self.nglitch, self.label,
+                                            self.sftlabel))
+        if os.path.isdir(outdir) is False:
+            os.mkdir(outdir)
+        self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
+        self.unpack_input_theta()
+        self.ndim = len(self.theta_keys)
+        self.sft_filepath = self.sftdir+'/*_'+self.sftlabel+"*sft"
+        if earth_ephem is None:
+            self.earth_ephem = self.earth_ephem_default
+        if sun_ephem is None:
+            self.sun_ephem = self.sun_ephem_default
+
+        if args.clean and os.path.isfile(self.pickle_path):
+            os.rename(self.pickle_path, self.pickle_path+".old")
+
+        self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
+
+    def inititate_search_object(self):
+        logging.info('Setting up search object')
+        self.search = SemiCoherentGlitchSearch(
+            label=self.label, outdir=self.outdir, sftlabel=self.sftlabel,
+            sftdir=self.sftdir, tref=self.tref, tstart=self.tstart,
+            tend=self.tend, minCoverFreq=self.minCoverFreq,
+            maxCoverFreq=self.maxCoverFreq, earth_ephem=self.earth_ephem,
+            sun_ephem=self.sun_ephem, detector=self.detector,
+            nglitch=self.nglitch)
+
+    def logp(self, theta_vals, theta_prior, theta_keys, search):
+        if self.nglitch > 1:
+            ts = [self.tstart] + list(theta_vals[-self.nglitch:]) + [self.tend]
+            if np.array_equal(ts, np.sort(ts)) is False:
+                return -np.inf
+            if any(np.diff(ts) < self.dtglitchmin):
+                return -np.inf
+
+        H = [self.generic_lnprior(**theta_prior[key])(p) for p, key in
+             zip(theta_vals, theta_keys)]
+        return np.sum(H)
+
+    def logl(self, theta, search):
+        for j, theta_i in enumerate(self.theta_idxs):
+            self.fixed_theta[theta_i] = theta[j]
+        FS = search.compute_nglitch_fstat(*self.fixed_theta)
+        return FS
+
+    def unpack_input_theta(self):
+        glitch_keys = ['delta_F0', 'delta_F1', 'tglitch']
+        full_glitch_keys = list(np.array(
+            [[gk]*self.nglitch for gk in glitch_keys]).flatten())
+        full_theta_keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta']+full_glitch_keys
+        full_theta_keys_copy = copy.copy(full_theta_keys)
+
+        glitch_symbols = [r'$\delta f$', r'$\delta \dot{f}$', r'$t_{glitch}$']
+        full_glitch_symbols = list(np.array(
+            [[gs]*self.nglitch for gs in glitch_symbols]).flatten())
+        full_theta_symbols = (['$f$', r'$\dot{f}$', r'$\ddot{f}$', r'$\alpha$',
+                               r'$\delta$'] + full_glitch_symbols)
+        self.theta_keys = []
+        fixed_theta_dict = {}
+        for key, val in self.theta_prior.iteritems():
+            if type(val) is dict:
+                fixed_theta_dict[key] = 0
+                if key in glitch_keys:
+                    for i in range(self.nglitch):
+                        self.theta_keys.append(key)
+                else:
+                    self.theta_keys.append(key)
+            elif type(val) in [float, int, np.float64]:
+                fixed_theta_dict[key] = val
+            else:
+                raise ValueError(
+                    'Type {} of {} in theta not recognised'.format(
+                        type(val), key))
+            if key in glitch_keys:
+                for i in range(self.nglitch):
+                    full_theta_keys_copy.pop(full_theta_keys_copy.index(key))
+            else:
+                full_theta_keys_copy.pop(full_theta_keys_copy.index(key))
+
+        if len(full_theta_keys_copy) > 0:
+            raise ValueError(('Input dictionary `theta` is missing the '
+                              'following keys: {}').format(
+                                  full_theta_keys_copy))
+
+        self.fixed_theta = [fixed_theta_dict[key] for key in full_theta_keys]
+        self.theta_idxs = [full_theta_keys.index(k) for k in self.theta_keys]
+        self.theta_symbols = [full_theta_symbols[i] for i in self.theta_idxs]
+
+        idxs = np.argsort(self.theta_idxs)
+        self.theta_idxs = [self.theta_idxs[i] for i in idxs]
+        self.theta_symbols = [self.theta_symbols[i] for i in idxs]
+        self.theta_keys = [self.theta_keys[i] for i in idxs]
+
+        # Correct for number of glitches in the idxs
+        self.theta_idxs = np.array(self.theta_idxs)
+        while np.sum(self.theta_idxs[:-1] == self.theta_idxs[1:]) > 0:
+            for i, idx in enumerate(self.theta_idxs):
+                if idx in self.theta_idxs[:i]:
+                    self.theta_idxs[i] += 1
+
+    def apply_corrections_to_p0(self, p0):
+        p0 = np.array(p0)
+        if self.nglitch > 1:
+            p0[:, :, -self.nglitch:] = np.sort(p0[:, :, -self.nglitch:],
+                                               axis=2)
+        return p0
+
+
 class GridGlitchSearch(BaseSearchClass):
     """ Gridded search using the SemiCoherentGlitchSearch """
     @initializer
diff --git a/tests/tests.py b/tests/tests.py
index 98a3c3c..176b75f 100644
--- a/tests/tests.py
+++ b/tests/tests.py
@@ -137,7 +137,7 @@ class TestSemiCoherentGlitchSearch(unittest.TestCase):
         self.assertTrue(np.abs((FS - predicted_FS))/predicted_FS < 0.3)
 
 
-class TestMCMCGlitchSearch(unittest.TestCase):
+class TestMCMCSearch(unittest.TestCase):
     label = "MCMCTest"
     outdir = 'TestData'
 
@@ -165,13 +165,12 @@ class TestMCMCGlitchSearch(unittest.TestCase):
         Writer.make_data()
         predicted_FS = Writer.predict_fstat()
 
-        theta = {'delta_F0': 0, 'delta_F1': 0, 'tglitch': tend,
-                 'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-9*F0)},
+        theta = {'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-9*F0)},
                  'F1': {'type': 'norm', 'loc': F1, 'scale': np.abs(1e-9*F1)},
                  'F2': F2, 'Alpha': Alpha, 'Delta': Delta}
 
-        search = pyfstat.MCMCGlitchSearch(
-            label=self.label, outdir=self.outdir, theta=theta, tref=tref,
+        search = pyfstat.MCMCSearch(
+            label=self.label, outdir=self.outdir, theta_prior=theta, tref=tref,
             sftlabel=self.label, sftdir=self.outdir,
             tstart=tstart, tend=tend, nsteps=[100, 100], nwalkers=100,
             ntemps=1)
@@ -181,7 +180,8 @@ class TestMCMCGlitchSearch(unittest.TestCase):
 
         print('Predicted twoF is {} while recovered is {}'.format(
                 predicted_FS, FS))
-        self.assertTrue(np.abs((FS - predicted_FS))/predicted_FS < 0.3)
+        self.assertTrue(
+            FS > predicted_FS or np.abs((FS-predicted_FS))/predicted_FS < 0.3)
 
 
 if __name__ == '__main__':
-- 
GitLab