From a42f0b0a46191f26c1114c7ec104b885ad5b9190 Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Wed, 11 Oct 2017 09:31:08 +0200
Subject: [PATCH] Minor improvements to user interface

- Remove minStartTime, maxStartTime and outdir as default arguments
- Adds notes ot documentation on which arguments are optional
- Change log-level on some command line calls that dont' matter
- Reorganise the tests performed after loading the data
- If minStartTime and maxStartTime are None, set them using
  SFT_timestamps
- Remove default labels from plot_twoF_cumulative
- If add_pfs is called, call generate_loudest automatically
- Save min/maxStartTime in pickle and load if required
---
 pyfstat/core.py                | 39 +++++++++++------
 pyfstat/mcmc_based_searches.py | 80 +++++++++++++++++++++++-----------
 2 files changed, 79 insertions(+), 40 deletions(-)

diff --git a/pyfstat/core.py b/pyfstat/core.py
index d54634d..a508765 100755
--- a/pyfstat/core.py
+++ b/pyfstat/core.py
@@ -423,12 +423,10 @@ class ComputeFstat(BaseSearchClass):
             constraints.minStartTime = lal.LIGOTimeGPS(self.minStartTime)
         if self.maxStartTime:
             constraints.maxStartTime = lal.LIGOTimeGPS(self.maxStartTime)
-
         logging.info('Loading data matching pattern {}'.format(
                      self.sftfilepattern))
         SFTCatalog = lalpulsar.SFTdataFind(self.sftfilepattern, constraints)
-        detector_names = list(set([d.header.name for d in SFTCatalog.data]))
-        self.detector_names = detector_names
+
         SFT_timestamps = [d.header.epoch for d in SFTCatalog.data]
         self.SFT_timestamps = [float(s) for s in SFT_timestamps]
         if len(SFT_timestamps) == 0:
@@ -440,21 +438,33 @@ class ComputeFstat(BaseSearchClass):
                 plot_hist(SFT_timestamps, height=5, bincount=50)
             except ImportError:
                 pass
-        if len(detector_names) == 0:
-            raise ValueError('No data loaded.')
-        logging.info('Loaded {} data files from detectors {}'.format(
-            len(SFT_timestamps), detector_names))
+
         cl_tconv1 = 'lalapps_tconvert {}'.format(int(SFT_timestamps[0]))
-        output = helper_functions.run_commandline(cl_tconv1)
+        output = helper_functions.run_commandline(cl_tconv1,
+                                                  log_level=logging.DEBUG)
         tconvert1 = output.rstrip('\n')
         cl_tconv2 = 'lalapps_tconvert {}'.format(int(SFT_timestamps[-1]))
-        output = helper_functions.run_commandline(cl_tconv2)
+        output = helper_functions.run_commandline(cl_tconv2,
+                                                  log_level=logging.DEBUG)
         tconvert2 = output.rstrip('\n')
         logging.info('Data spans from {} ({}) to {} ({})'.format(
             int(SFT_timestamps[0]),
             tconvert1,
             int(SFT_timestamps[-1]),
             tconvert2))
+
+        if self.minStartTime is None:
+            self.minStartTime = int(SFT_timestamps[0])
+        if self.maxStartTime is None:
+            self.maxStartTime = int(SFT_timestamps[-1])
+
+        detector_names = list(set([d.header.name for d in SFTCatalog.data]))
+        self.detector_names = detector_names
+        if len(detector_names) == 0:
+            raise ValueError('No data loaded.')
+        logging.info('Loaded {} data files from detectors {}'.format(
+            len(SFT_timestamps), detector_names))
+
         return SFTCatalog
 
     def init_computefstatistic_single_point(self):
@@ -735,7 +745,7 @@ class ComputeFstat(BaseSearchClass):
 
     def plot_twoF_cumulative(self, label, outdir, add_pfs=False, N=15,
                              injectSources=None, ax=None, c='k', savefig=True,
-                             title=None, **kwargs):
+                             title=None, plt_label=None, **kwargs):
         """ Plot the twoF value cumulatively
 
         Parameters
@@ -753,8 +763,8 @@ class ComputeFstat(BaseSearchClass):
             Colour
         savefig : bool
             If true, save the figure in outdir
-        title: str
-            Figure title
+        title, plt_label: str
+            Figure title and label
 
         Returns
         -------
@@ -775,7 +785,7 @@ class ComputeFstat(BaseSearchClass):
             pfs_input = None
 
         taus, twoFs = self.calculate_twoF_cumulative(**kwargs)
-        ax.plot(taus/86400., twoFs, label='All detectors', color=c)
+        ax.plot(taus/86400., twoFs, label=plt_label, color=c)
         if len(self.detector_names) > 1:
             detector_names = self.detector_names
             detectors = self.detectors
@@ -819,7 +829,8 @@ class ComputeFstat(BaseSearchClass):
         else:
             ax.set_ylabel(r'$\widetilde{2\mathcal{F}}_{\rm cumulative}$')
         ax.set_xlim(0, taus[-1]/86400)
-        ax.legend(frameon=False, loc=2, fontsize=6)
+        if plt_label:
+            ax.legend(frameon=False, loc=2, fontsize=6)
         if title:
             ax.set_title(title)
         if savefig:
diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py
index 06445f8..277c5cb 100644
--- a/pyfstat/mcmc_based_searches.py
+++ b/pyfstat/mcmc_based_searches.py
@@ -26,52 +26,55 @@ class MCMCSearch(core.BaseSearchClass):
 
     Parameters
     ----------
-    label, outdir: str
-        A label and directory to read/write data from/to
     theta_prior: dict
         Dictionary of priors and fixed values for the search parameters.
         For each parameters (key of the dict), if it is to be held fixed
         the value should be the constant float, if it is be searched, the
         value should be a dictionary of the prior.
     tref, minStartTime, maxStartTime: int
-        GPS seconds of the reference time, start time and end time
-    sftfilepattern: str
+        GPS seconds of the reference time, start time and end time. While tref
+        is requirede, minStartTime and maxStartTime default to None in which
+        case all available data is used.
+    label, outdir: str
+        A label and output directory (optional, defaults is `'data'`) to
+        name files
+    sftfilepattern: str, optional
         Pattern to match SFTs using wildcards (*?) and ranges [0-9];
         mutiple patterns can be given separated by colons.
-    detectors: str
+    detectors: str, optional
         Two character reference to the detectors to use, specify None for no
         contraint and comma separate for multiple references.
-    nsteps: list (2,)
+    nsteps: list (2,), optional
         Number of burn-in and production steps to take, [nburn, nprod]. See
         `pyfstat.MCMCSearch.setup_initialisation()` for details on adding
         initialisation steps.
-    nwalkers, ntemps: int,
+    nwalkers, ntemps: int, optional
         The number of walkers and temperates to use in the parallel
         tempered PTSampler.
-    log10beta_min float < 0
+    log10beta_min float < 0, optional
         The  log_10(beta) value, if given the set of betas passed to PTSampler
         are generated from `np.logspace(0, log10beta_min, ntemps)` (given
         in descending order to emcee).
-    theta_initial: dict, array, (None)
+    theta_initial: dict, array, optional
         A dictionary of distribution about which to distribute the
         initial walkers about
-    rhohatmax: float,
+    rhohatmax: float, optional
         Upper bound for the SNR scale parameter (required to normalise the
         Bayes factor) - this needs to be carefully set when using the
         evidence.
-    binary: bool
+    binary: bool, optional
         If true, search over binary parameters
-    BSGL: bool
+    BSGL: bool, optional
         If true, use the BSGL statistic
-    SSBPrec: int
+    SSBPrec: int, optional
         SSBPrec (SSB precision) to use when calling ComputeFstat
-    minCoverFreq, maxCoverFreq: float
+    minCoverFreq, maxCoverFreq: float, optional
         Minimum and maximum instantaneous frequency which will be covered
         over the SFT time span as passed to CreateFstatInput
-    injectSources: dict
+    injectSources: dict, optional
         If given, inject these properties into the SFT files before running
         the search
-    assumeSqrtSX: float
+    assumeSqrtSX: float, optional
         Don't estimate noise-floors, but assume (stationary) per-IFO sqrt{SX}
 
     Attributes
@@ -99,9 +102,9 @@ class MCMCSearch(core.BaseSearchClass):
     transform_dictionary = {}
 
     @helper_functions.initializer
-    def __init__(self, label, outdir, theta_prior, tref, minStartTime,
-                 maxStartTime, sftfilepattern=None, detectors=None,
-                 nsteps=[100, 100], nwalkers=100, ntemps=1,
+    def __init__(self, theta_prior, tref, label, outdir='data',
+                 minStartTime=None, maxStartTime=None, sftfilepattern=None,
+                 detectors=None, nsteps=[100, 100], nwalkers=100, ntemps=1,
                  log10beta_min=-5, theta_initial=None,
                  rhohatmax=1000, binary=False, BSGL=False,
                  SSBprec=None, minCoverFreq=None, maxCoverFreq=None,
@@ -151,6 +154,10 @@ class MCMCSearch(core.BaseSearchClass):
             minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
             binary=self.binary, injectSources=self.injectSources,
             assumeSqrtSX=self.assumeSqrtSX, SSBprec=self.SSBprec)
+        if self.minStartTime is None:
+            self.minStartTime = self.search.minStartTime
+        if self.maxStartTime is None:
+            self.maxStartTime = self.search.maxStartTime
 
     def logp(self, theta_vals, theta_prior, theta_keys, search):
         H = [self._generic_lnprior(**theta_prior[key])(p) for p, key in
@@ -830,6 +837,9 @@ class MCMCSearch(core.BaseSearchClass):
             if key not in d:
                 d[key] = val
 
+        if 'add_pfs' in kwargs:
+            self.generate_loudest()
+
         if hasattr(self, 'search') is False:
             self._initiate_search_object()
         if self.binary is False:
@@ -1165,7 +1175,8 @@ class MCMCSearch(core.BaseSearchClass):
                  ntemps=self.ntemps, theta_keys=self.theta_keys,
                  theta_prior=self.theta_prior,
                  log10beta_min=self.log10beta_min,
-                 BSGL=self.BSGL)
+                 BSGL=self.BSGL, minStartTime=self.minStartTime,
+                 maxStartTime=self.maxStartTime)
         return d
 
     def _save_data(self, sampler, samples, lnprobs, lnlikes, all_lnlikelihood):
@@ -1212,6 +1223,11 @@ class MCMCSearch(core.BaseSearchClass):
         old_d.pop('lnlikes')
         old_d.pop('all_lnlikelihood')
 
+        for key in 'minStartTime', 'maxStartTime':
+            if new_d[key] is None:
+                new_d[key] = old_d[key]
+                setattr(self, key, new_d[key])
+
         mod_keys = []
         for key in new_d.keys():
             if key in old_d:
@@ -1569,9 +1585,9 @@ class MCMCGlitchSearch(MCMCSearch):
             )
 
     @helper_functions.initializer
-    def __init__(self, label, outdir, theta_prior, tref, minStartTime,
-                 maxStartTime, sftfilepattern=None, detectors=None,
-                 nsteps=[100, 100], nwalkers=100, ntemps=1,
+    def __init__(self, theta_prior, tref, label, outdir='data',
+                 minStartTime=None, maxStartTime=None, sftfilepattern=None,
+                 detectors=None, nsteps=[100, 100], nwalkers=100, ntemps=1,
                  log10beta_min=-5, theta_initial=None,
                  rhohatmax=1000, binary=False, BSGL=False,
                  SSBprec=None, minCoverFreq=None, maxCoverFreq=None,
@@ -1610,6 +1626,10 @@ class MCMCGlitchSearch(MCMCSearch):
             minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq,
             detectors=self.detectors, BSGL=self.BSGL, nglitch=self.nglitch,
             theta0_idx=self.theta0_idx, injectSources=self.injectSources)
+        if self.minStartTime is None:
+            self.minStartTime = self.search.minStartTime
+        if self.maxStartTime is None:
+            self.maxStartTime = self.search.maxStartTime
 
     def logp(self, theta_vals, theta_prior, theta_keys, search):
         if self.nglitch > 1:
@@ -1778,9 +1798,9 @@ class MCMCSemiCoherentSearch(MCMCSearch):
     """
 
     @helper_functions.initializer
-    def __init__(self, label, outdir, theta_prior, tref, minStartTime,
-                 maxStartTime, sftfilepattern=None, detectors=None,
-                 nsteps=[100, 100], nwalkers=100, ntemps=1,
+    def __init__(self, theta_prior, tref, label, outdir='data',
+                 minStartTime=None, maxStartTime=None, sftfilepattern=None,
+                 detectors=None, nsteps=[100, 100], nwalkers=100, ntemps=1,
                  log10beta_min=-5, theta_initial=None,
                  rhohatmax=1000, binary=False, BSGL=False,
                  SSBprec=None, minCoverFreq=None, maxCoverFreq=None,
@@ -1830,6 +1850,10 @@ class MCMCSemiCoherentSearch(MCMCSearch):
             maxStartTime=self.maxStartTime, minCoverFreq=self.minCoverFreq,
             maxCoverFreq=self.maxCoverFreq, detectors=self.detectors,
             injectSources=self.injectSources, assumeSqrtSX=self.assumeSqrtSX)
+        if self.minStartTime is None:
+            self.minStartTime = self.search.minStartTime
+        if self.maxStartTime is None:
+            self.maxStartTime = self.search.maxStartTime
 
     def logp(self, theta_vals, theta_prior, theta_keys, search):
         H = [self._generic_lnprior(**theta_prior[key])(p) for p, key in
@@ -2144,6 +2168,10 @@ class MCMCTransientSearch(MCMCSearch):
             minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
             BSGL=self.BSGL, binary=self.binary,
             injectSources=self.injectSources)
+        if self.minStartTime is None:
+            self.minStartTime = self.search.minStartTime
+        if self.maxStartTime is None:
+            self.maxStartTime = self.search.maxStartTime
 
     def logl(self, theta, search):
         for j, theta_i in enumerate(self.theta_idxs):
-- 
GitLab