From 9224e7414083b440ef0484ee1dce134fb6157e40 Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Tue, 15 Nov 2016 17:29:43 +0100
Subject: [PATCH] Various improvements to generate publication-quality tables
 and plots

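Add a gridded frequency-search example (grided_frequency_search.py,
which assumes a local 'paper' matplotlib style sheet), retune the
weak-signal follow-up example, and extend pyfstat.py with round_to_n
and texify_float helpers, an init_run_setup method that logs the
follow-up run-setup and writes it out as a LaTeX table, a configurable
style context and axis labels in plot_walkers, and consistent dpi
settings for the saved figures.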
---
 examples/fully_coherent_search_using_MCMC.py |   2 +-
 examples/grided_frequency_search.py          |  64 +++++++
 examples/weak_signal_follow_up.py            |  57 +++---
 pyfstat.py                                   | 192 ++++++++++++-------
 4 files changed, 224 insertions(+), 91 deletions(-)
 create mode 100644 examples/grided_frequency_search.py

diff --git a/examples/fully_coherent_search_using_MCMC.py b/examples/fully_coherent_search_using_MCMC.py
index 7372591..7e866bc 100644
--- a/examples/fully_coherent_search_using_MCMC.py
+++ b/examples/fully_coherent_search_using_MCMC.py
@@ -18,7 +18,7 @@ theta_prior = {'F0': {'type': 'unif', 'lower': F0*(1-1e-6), 'upper': F0*(1+1e-6)
                'Delta': Delta
                }
 
-ntemps = 4
+ntemps = 3
 log10temperature_min = -1
 nwalkers = 100
 nsteps = [1000, 1000]
diff --git a/examples/grided_frequency_search.py b/examples/grided_frequency_search.py
new file mode 100644
index 0000000..5e4a423
--- /dev/null
+++ b/examples/grided_frequency_search.py
@@ -0,0 +1,64 @@
+import pyfstat
+import numpy as np
+import matplotlib.pyplot as plt
+
+plt.style.use('paper')
+
+F0 = 30.0
+F1 = 0
+F2 = 0
+Alpha = 1.0
+Delta = 1.5
+
+# Properties of the GW data
+sqrtSX = 1e-23
+tstart = 1000000000
+duration = 100*86400
+tend = tstart+duration
+tref = .5*(tstart+tend)
+
+depth = 70
+data_label = 'grided_frequency_depth_{:1.0f}'.format(depth)
+
+h0 = sqrtSX / depth
+
+data = pyfstat.Writer(
+    label=data_label, outdir='data', tref=tref,
+    tstart=tstart, F0=F0, F1=F1, F2=F2, duration=duration, Alpha=Alpha,
+    Delta=Delta, h0=h0, sqrtSX=sqrtSX)
+data.make_data()
+
+m = 0.001
+dF0 = np.sqrt(12*m)/(np.pi*duration)
+DeltaF0 = 800*dF0
+F0s = [F0-DeltaF0/2., F0+DeltaF0/2., dF0]
+F1s = [F1]
+F2s = [F2]
+Alphas = [Alpha]
+Deltas = [Delta]
+search = pyfstat.GridSearch(
+    'grided_frequency_search', 'data', 'data/*'+data_label+'*sft', F0s, F1s,
+    F2s, Alphas, Deltas, tref, tstart, tend)
+search.run()
+
+fig, ax = plt.subplots()
+xidx = search.keys.index('F0')
+frequencies = np.unique(search.data[:, xidx])
+twoF = search.data[:, -1]
+
+#mismatch = np.sign(x-F0)*(duration * np.pi * (x - F0))**2 / 12.0
+ax.plot(frequencies, twoF, 'k', lw=1)
+DeltaF = frequencies - F0
+sinc = np.sinc(DeltaF*duration)  # sin(pi x)/(pi x), avoids 0/0 at DeltaF = 0
+A = np.abs((np.max(twoF)-4)*sinc**2 + 4)
+ax.plot(frequencies, A, '-r', lw=1)
+ax.set_ylabel(r'$\widetilde{2\mathcal{F}}$')
+ax.set_xlabel('Frequency')
+ax.set_xlim(F0s[0], F0s[1])
+dF0 = np.sqrt(12*1)/(np.pi*duration)
+xticks = [F0-10*dF0, F0, F0+10*dF0]
+ax.set_xticks(xticks)
+xticklabels = [r'$f_0 {-} 10\Delta f$', '$f_0$', r'$f_0 {+} 10\Delta f$']
+ax.set_xticklabels(xticklabels)
+plt.tight_layout()
+fig.savefig('{}/{}_1D.png'.format(search.outdir, search.label), dpi=300)
diff --git a/examples/weak_signal_follow_up.py b/examples/weak_signal_follow_up.py
index 629290e..a6286aa 100644
--- a/examples/weak_signal_follow_up.py
+++ b/examples/weak_signal_follow_up.py
@@ -1,25 +1,25 @@
 import pyfstat
 
-# Define parameters of the Crab pulsar as an example
 F0 = 30.0
-F1 = -1e-10
+F1 = 0
 F2 = 0
-Alpha = 5e-3
-Delta = 6e-2
-tref = 362750407.0
+Alpha = 1.0
+Delta = 0.5
 
 # Properties of the GW data
 sqrtSX = 1e-23
 tstart = 1000000000
 duration = 100*86400
 tend = tstart+duration
+tref = .5*(tstart+tend)
 
-depth = 50
+depth = 70
+data_label = 'weak_signal_follow_up_depth_{:1.0f}'.format(depth)
 
 h0 = sqrtSX / depth
 
 data = pyfstat.Writer(
-    label='depth_{:1.0f}'.format(depth), outdir='data', tref=tref,
+    label=data_label, outdir='data', tref=tref,
     tstart=tstart, F0=F0, F1=F1, F2=F2, duration=duration, Alpha=Alpha,
     Delta=Delta, h0=h0, sqrtSX=sqrtSX)
 data.make_data()
@@ -29,35 +29,40 @@ twoF = data.predict_fstat()
 print 'Predicted twoF value: {}\n'.format(twoF)
 
 # Search
-theta_prior = {'F0': {'type': 'unif', 'lower': F0*(1-1e-4),
-                      'upper': F0*(1+1e-4)},
-               'F1': {'type': 'unif', 'lower': F1*(1+1e-2),
-                      'upper': F1*(1-1e-2)},
+theta_prior = {'F0': {'type': 'unif', 'lower': F0*(1-1e-6),
+                      'upper': F0*(1+1e-6)},
+               'F1': F1, #{'type': 'unif', 'lower': F1*(1+1e-2),
+                      #'upper': F1*(1-1e-2)},
                'F2': F2,
                'Alpha': {'type': 'unif', 'lower': Alpha-1e-2,
                          'upper': Alpha+1e-2},
-               'Delta': {'type': 'unif', 'lower': Delta-5e-2,
-                         'upper': Delta+5e-2},
+               'Delta': {'type': 'unif', 'lower': Delta-1e-2,
+                         'upper': Delta+1e-2},
                }
 
-ntemps = 1
+ntemps = 3
 log10temperature_min = -1
-nwalkers = 100
-run_setup = [(1000, 50),
-             (1000, 30),
-             (1000, 20),
-             (1000, 15),
-             (1000, 10),
-             (1000, 5),
-             (1000, 1),
-             ((1000, 1000), 1, True)]
+nwalkers = 200
+scatter_val = 1e-10
+
+stages = 7
+steps = 100
+#run_setup = [(steps, 2**i) for i in reversed(range(1, stages+1))]
+#run_setup.append(((steps, steps), 1, True))
+run_setup = [(steps, 80),
+             (steps, 40),
+             (steps, 20),
+             (steps, 10),
+             (steps, 5),
+             ((steps, steps), 1, False)]
 
 mcmc = pyfstat.MCMCFollowUpSearch(
     label='weak_signal_follow_up', outdir='data',
-    sftfilepath='data/*depth*sft', theta_prior=theta_prior, tref=tref,
+    sftfilepath='data/*'+data_label+'*sft', theta_prior=theta_prior, tref=tref,
     minStartTime=tstart, maxStartTime=tend, nwalkers=nwalkers,
-    ntemps=ntemps, log10temperature_min=log10temperature_min)
+    ntemps=ntemps, log10temperature_min=log10temperature_min,
+    scatter_val=scatter_val)
 mcmc.run(run_setup)
 mcmc.plot_corner(add_prior=True)
 mcmc.print_summary()
-mcmc.generate_loudest()
+#mcmc.generate_loudest()
diff --git a/pyfstat.py b/pyfstat.py
index 04515a0..e558e31 100755
--- a/pyfstat.py
+++ b/pyfstat.py
@@ -70,6 +70,26 @@ stream_handler.setFormatter(logging.Formatter(
 logger.addHandler(stream_handler)
 
 
+def round_to_n(x, n):
+    if not x:
+        return 0
+    power = -int(np.floor(np.log10(abs(x)))) + (n - 1)
+    factor = (10 ** power)
+    return round(x * factor) / factor
+
+
+def texify_float(x, d=1):
+    if isinstance(x, str) or x == 0:
+        return str(x)
+    x = round_to_n(x, d)
+    if 0.01 < abs(x) < 100:
+        return str(x)
+    power = int(np.floor(np.log10(abs(x))))
+    stem = np.round(x / 10**power, d)
+    stem = int(stem) if d == 1 else stem
+    return r'${}{{\times}}10^{{{}}}$'.format(stem, power)
+
+
 def initializer(func):
     """ Decorator function to automatically assign the parameters to self """
     names, varargs, keywords, defaults = inspect.getargspec(func)
@@ -902,7 +922,7 @@ class MCMCSearch(BaseSearchClass):
             pass
         return sampler
 
-    def run(self, proposal_scale_factor=2):
+    def run(self, proposal_scale_factor=2, **kwargs):
 
         self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
         if self.old_data_is_okay_to_use is True:
@@ -936,9 +956,10 @@ class MCMCSearch(BaseSearchClass):
             if self.ntemps > 1:
                 logging.info("Tswap acceptance fraction: {}"
                              .format(sampler.tswap_acceptance_fraction))
-            fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols)
+            fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols,
+                                          **kwargs)
             fig.savefig('{}/{}_init_{}_walkers.png'.format(
-                self.outdir, self.label, j))
+                self.outdir, self.label, j), dpi=200)
 
             p0 = self.get_new_p0(sampler)
             p0 = self.apply_corrections_to_p0(p0)
@@ -960,8 +981,9 @@ class MCMCSearch(BaseSearchClass):
                          .format(sampler.tswap_acceptance_fraction))
 
         fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols,
-                                      burnin_idx=nburn)
-        fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label))
+                                      burnin_idx=nburn, **kwargs)
+        fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label),
+                    dpi=200)
 
         samples = sampler.chain[0, :, nburn:, :].reshape((-1, self.ndim))
         lnprobs = sampler.lnprobability[0, :, nburn:].reshape((-1))
@@ -1200,7 +1222,8 @@ class MCMCSearch(BaseSearchClass):
 
     def plot_walkers(self, sampler, symbols=None, alpha=0.4, color="k", temp=0,
                      lw=0.1, burnin_idx=None, add_det_stat_burnin=False,
-                     fig=None, axes=None, xoffset=0, plot_det_stat=True):
+                     fig=None, axes=None, xoffset=0, plot_det_stat=True,
+                     context='classic'):
         """ Plot all the chains from a sampler """
 
         shape = sampler.chain.shape
@@ -1216,7 +1239,7 @@ class MCMCSearch(BaseSearchClass):
                                   "available range").format(temp))
             chain = sampler.chain[temp, :, :, :]
 
-        with plt.style.context(('classic')):
+        with plt.style.context(context):
             if fig is None and axes is None:
                 fig = plt.figure(figsize=(8, 4*ndim))
                 ax = fig.add_subplot(ndim+1, 1, 1)
@@ -1227,6 +1250,8 @@ class MCMCSearch(BaseSearchClass):
             if ndim > 1:
                 for i in range(ndim):
                     axes[i].ticklabel_format(useOffset=False, axis='y')
+                    if i < ndim - 1:
+                        axes[i].set_xticklabels([])
                     cs = chain[:, :, i].T
                     if burnin_idx:
                         axes[i].plot(xoffset+idxs[:burnin_idx],
@@ -1247,31 +1272,33 @@ class MCMCSearch(BaseSearchClass):
                 if symbols:
                     axes[0].set_ylabel(symbols[0])
 
-        if len(axes) == ndim:
-            axes.append(fig.add_subplot(ndim+1, 1, ndim+1))
-
-        if plot_det_stat:
-            lnl = sampler.lnlikelihood[temp, :, :]
-            if burnin_idx and add_det_stat_burnin:
-                burn_in_vals = lnl[:, :burnin_idx].flatten()
-                axes[-1].hist(burn_in_vals[~np.isnan(burn_in_vals)], bins=50,
-                              histtype='step', color='r')
-            else:
-                burn_in_vals = []
-            prod_vals = lnl[:, burnin_idx:].flatten()
-            axes[-1].hist(prod_vals[~np.isnan(prod_vals)], bins=50,
-                          histtype='step', color='k')
-            if self.BSGL:
-                axes[-1].set_xlabel(r'$\mathcal{B}_\mathrm{S/GL}$')
-            else:
-                axes[-1].set_xlabel(r'$2\mathcal{F}$')
-            combined_vals = np.append(burn_in_vals, prod_vals)
-            if len(combined_vals) > 0:
-                minv = np.min(combined_vals)
-                maxv = np.max(combined_vals)
-                Range = abs(maxv-minv)
-                axes[-1].set_xlim(minv-0.1*Range, maxv+0.1*Range)
+            if len(axes) == ndim:
+                axes.append(fig.add_subplot(ndim+1, 1, ndim+1))
 
+            if plot_det_stat:
+                lnl = sampler.lnlikelihood[temp, :, :]
+                if burnin_idx and add_det_stat_burnin:
+                    burn_in_vals = lnl[:, :burnin_idx].flatten()
+                    axes[-1].hist(burn_in_vals[~np.isnan(burn_in_vals)], bins=50,
+                                  histtype='step', color='r')
+                else:
+                    burn_in_vals = []
+                prod_vals = lnl[:, burnin_idx:].flatten()
+                axes[-1].hist(prod_vals[~np.isnan(prod_vals)], bins=50,
+                              histtype='step', color='k')
+                if self.BSGL:
+                    axes[-1].set_xlabel(r'$\mathcal{B}_\mathrm{S/GL}$')
+                else:
+                    axes[-1].set_xlabel(r'$\widetilde{2\mathcal{F}}$')
+                axes[-1].set_ylabel(r'$\textrm{Counts}$')
+                combined_vals = np.append(burn_in_vals, prod_vals)
+                if len(combined_vals) > 0:
+                    minv = np.min(combined_vals)
+                    maxv = np.max(combined_vals)
+                    Range = abs(maxv-minv)
+                    axes[-1].set_xlim(minv-0.1*Range, maxv+0.1*Range)
+
+            axes[-2].set_xlabel(r'$\textrm{Number of steps}$', labelpad=0.1)
         return fig, axes
 
     def apply_corrections_to_p0(self, p0):
@@ -1970,6 +1997,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
             return .5*(prior[key]['upper'] + prior[key]['lower'])
 
     def get_number_of_templates_estimate(self, nsegs):
+        """ Returns V, Vsky, Vf estimated from the super-sky metric """
         tboundaries = np.linspace(self.minStartTime, self.maxStartTime,
                                   nsegs+1)
         if 'Alpha' in self.theta_keys:
@@ -2016,30 +2044,17 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
                     SSkyMetric.semi_rssky_metric.data[:2, :2]))
                 sqrtdetG_F = np.sqrt(np.linalg.det(
                     SSkyMetric.semi_rssky_metric.data[2:, 2:]))
-                return '{:1.0e} = 1/2 x {:1.0e} x {:1.0e}'.format(
-                        .5*sqrtdetG_SKY*sqrtdetG_F*DeltaOmega*DeltaF1*DeltaF0,
-                        sqrtdetG_SKY*DeltaOmega,
+                return (.5*sqrtdetG_SKY*sqrtdetG_F*DeltaOmega*DeltaF1*DeltaF0,
+                        .5*sqrtdetG_SKY*DeltaOmega,
                         sqrtdetG_F*DeltaF1*DeltaF0)
             except RuntimeError:
                 return 'N/A'
         elif self.theta_keys == ['F0', 'F1']:
             return 'N/A'
 
-    def run(self, run_setup, proposal_scale_factor=2):
-        """ Run the follow-up with the given run_setup
-
-        Parameters
-        ----------
-        run_setup: list of tuples
-
-        """
-
-        self.nsegs = 1
-        self.inititate_search_object()
-
-        logging.info('Using run-setup as follow:')
-        logging.info('Stage | nburn | nprod | nsegs | resetp0 |'
-                     '# templates = # sky x # Freq')
+    def init_run_setup(self, run_setup, log_table=True, gen_tex_table=True):
+        logging.info('Calculating the number of templates for this setup...')
+        number_of_templates = []
         for i, rs in enumerate(run_setup):
             rs = list(rs)
             if len(rs) == 2:
@@ -2047,17 +2062,67 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
             if np.shape(rs[0]) == ():
                 rs[0] = (rs[0], 0)
             run_setup[i] = rs
-            N = self.get_number_of_templates_estimate(rs[1])
-            logging.info('{} | {} | {} | {} | {} | {}'.format(
-                str(i).ljust(5), str(rs[0][0]).ljust(5),
-                str(rs[0][1]).ljust(5), str(rs[1]).ljust(5),
-                str(rs[2]).ljust(7), N))
+            number_of_templates.append(
+                    self.get_number_of_templates_estimate(rs[1]))
+
+        if log_table:
+            logging.info('Using run-setup as follows:')
+            logging.info('Stage | nburn | nprod | nsegs | resetp0 |'
+                         '# templates = # sky x # Freq')
+            for i, rs in enumerate(run_setup):
+                if number_of_templates[i] != 'N/A':
+                    vtext = '{:1.0e} = {:1.0e} x {:1.0e}'.format(
+                            *number_of_templates[i])
+                else:
+                    vtext = number_of_templates[i]
+                logging.info('{} | {} | {} | {} | {} | {}'.format(
+                    str(i).ljust(5), str(rs[0][0]).ljust(5),
+                    str(rs[0][1]).ljust(5), str(rs[1]).ljust(5),
+                    str(rs[2]).ljust(7), vtext))
+
+        if gen_tex_table:
+            filename = '{}/{}_run_setup.tex'.format(self.outdir, self.label)
+            with open(filename, 'w+') as f:
+                f.write(r'\begin{tabular}{c|cccccc}' + '\n')
+                f.write(r'Stage & $\Nseg$ & $\Tcoh^{\rm days}$ &'
+                        r'$\Nsteps$ & $\V$ & $\Vsky$ & $\Vf$ \\ \hline'
+                        '\n')
+                for i, rs in enumerate(run_setup):
+                    Tcoh = (self.maxStartTime - self.minStartTime) / rs[1] / 86400
+                    line = r'{} & {} & {} & {} & {} & {} & {} \\' + '\n'
+                    if number_of_templates[i] == 'N/A':
+                        V = Vsky = Vf = 'N/A'
+                    else:
+                        V, Vsky, Vf = number_of_templates[i]
+                    if rs[0][-1] == 0:
+                        nsteps = rs[0][0]
+                    else:
+                        nsteps = '{},{}'.format(*rs[0])
+                    line = line.format(i, rs[1], Tcoh, nsteps,
+                                       texify_float(V), texify_float(Vsky),
+                                       texify_float(Vf))
+                    f.write(line)
+                f.write(r'\end{tabular}' + '\n')
 
         if args.setup_only:
             logging.info("Exit as requested by setup_only flag")
             sys.exit()
+        else:
+            return run_setup
+
+    def run(self, run_setup, proposal_scale_factor=2, **kwargs):
+        """ Run the follow-up with the given run_setup
+
+        Parameters
+        ----------
+        run_setup: list of tuples
+            Each stage is (nsteps, nsegs) or ((nburn, nprod), nsegs, resetp0)
+        """
+
+        self.nsegs = 1
+        self.inititate_search_object()
+        self.run_setup = self.init_run_setup(run_setup)
 
-        self.run_setup = run_setup
         self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
         if self.old_data_is_okay_to_use is True:
             logging.warning('Using saved data from {}'.format(
@@ -2109,18 +2174,19 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
 
             fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols,
                                           fig=fig, axes=axes, burnin_idx=nburn,
-                                          xoffset=nsteps_total)
-            yvals = axes[0].get_ylim()
-            axes[0].annotate(
-                r'$T_{{\rm coh}}^{{\rm (days)}}{{=}}{:1.1f}$'.format(Tcoh),
-                xy=(nsteps_total, yvals[0]*(1+1e-2*(yvals[1]-yvals[0])/yvals[1])),
-                fontsize=6)
+                                          xoffset=nsteps_total, **kwargs)
+            #yvals = axes[0].get_ylim()
+            #axes[0].annotate(
+            #    #r'$T_{{\rm coh}}^{{\rm (days)}}{{=}}{:1.1f}$'.format(Tcoh),
+            #    r'{}'.format(j),
+            #    xy=(nsteps_total, yvals[0]*(1+1e-2*(yvals[1]-yvals[0])/yvals[1])),
+            #    fontsize=6)
             for ax in axes[:-1]:
                 ax.axvline(nsteps_total, color='k', ls='--')
             nsteps_total += nburn+nprod
 
             fig.savefig('{}/{}_walkers.png'.format(
-                self.outdir, self.label), dpi=600)
+                self.outdir, self.label), dpi=200)
 
         samples = sampler.chain[0, :, nburn:, :].reshape((-1, self.ndim))
         lnprobs = sampler.lnprobability[0, :, nburn:].reshape((-1))
@@ -2174,8 +2240,6 @@ class MCMCTransientSearch(MCMCSearch):
 
         self.theta_keys = []
         fixed_theta_dict = {}
-        self.theta_prior.pop('tstart')
-        self.theta_prior.pop('tend')
         for key, val in self.theta_prior.iteritems():
             if type(val) is dict:
                 fixed_theta_dict[key] = 0
-- 
GitLab