Skip to content
Snippets Groups Projects
Select Git revision
  • 5c22976b9c4f1acba793d73875a11afd0991e9c0
  • master default
2 results

pykat_LLO.py

Blame
  • Forked from finesse / pykat
    Source project has a limited visibility.
    grid_based_searches.py 46.97 KiB
    """ Searches using grid-based methods """
    from __future__ import division, absolute_import, print_function
    
    import os
    import logging
    import itertools
    from collections import OrderedDict
    import datetime
    import getpass
    import socket
    
    import numpy as np
    import matplotlib
    import matplotlib.pyplot as plt
    from scipy.misc import logsumexp
    
    import pyfstat.helper_functions as helper_functions
    from pyfstat.core import (BaseSearchClass, ComputeFstat,
                              SemiCoherentGlitchSearch, SemiCoherentSearch, tqdm,
                              args, read_par)
    import lalpulsar
    import lal
    
    
    class GridSearch(BaseSearchClass):
        """ Gridded search using ComputeFstat """
        # LaTeX axis labels per search parameter. Raw strings are used so the
        # backslashes reach matplotlib's mathtext untouched instead of relying
        # on Python leaving unknown escapes such as '\d' alone.
        tex_labels = {'F0': r'$f$', 'F1': r'$\dot{f}$', 'F2': r'$\ddot{f}$',
                      'Alpha': r'$\alpha$', 'Delta': r'$\delta$'}
        # Suffixes appended to `tex_labels` when a reference value has been
        # subtracted from the plotted axis.
        tex_labels0 = {'F0': r'$-f_0$', 'F1': r'$-\dot{f}_0$',
                       'F2': r'$-\ddot{f}_0$',
                       'Alpha': r'$-\alpha_0$', 'Delta': r'$-\delta_0$'}
        # Attribute names read (in this order) to build the search grid.
        search_labels = ['minStartTime', 'maxStartTime', 'F0s', 'F1s', 'F2s',
                         'Alphas', 'Deltas']
    
        @helper_functions.initializer
        def __init__(self, label, outdir, sftfilepattern, F0s, F1s, F2s, Alphas,
                     Deltas, tref=None, minStartTime=None, maxStartTime=None,
                     nsegs=1, BSGL=False, minCoverFreq=None, maxCoverFreq=None,
                     detectors=None, SSBprec=None, injectSources=None,
                     input_arrays=False, assumeSqrtSX=None):
            """ Set up a grid search over F0, F1, F2, Alpha and Delta

            The body reads the arguments back from `self`, so the
            `helper_functions.initializer` decorator is assumed to have stored
            every argument as an attribute of the same name beforehand.

            Parameters
            ----------
            label, outdir: str
                A label and directory to read/write data from/to
            sftfilepattern: str
                Pattern to match SFTs using wildcards (*?) and ranges [0-9];
                multiple patterns can be given separated by colons.
            F0s, F1s, F2s, Alphas, Deltas: tuple
                Length 3 tuple describing the grid for each parameter, e.g
                [F0min, F0max, dF0]; for a fixed value simply give [F0].
                If input_arrays == True, these are instead the values to
                search at.
            tref, minStartTime, maxStartTime: int
                GPS seconds of the reference time, start time and end time
            input_arrays: bool
                if true, use the F0s, F1s, etc as is

            For all other parameters, see `pyfstat.ComputeFStat` for details

            Note: if a large number of grid points are used, checks against
            cached data may be slow as the array is loaded into memory. To
            avoid this, run with the `clean` option which uses a generator
            instead.
            """

            if not os.path.isdir(outdir):
                os.mkdir(outdir)
            self.set_out_file()

            # The two leading placeholders stand for the start/end-time columns
            self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta']
            self.search_keys = [key + 's' for key in self.keys[2:]]
            # Normalise scalars/lists to 1D arrays so len() always works
            for name in self.search_keys:
                values = getattr(self, name)
                setattr(self, name, np.atleast_1d(values))
    
        def inititate_search_object(self):
            """ Create `self.search` and attach its `get_det_stat` callable

            A fully-coherent `ComputeFstat` is used when `self.nsegs == 1`,
            otherwise a `SemiCoherentSearch` is used.
            """
            logging.info('Setting up search object')
            # Keyword arguments passed to both kinds of search object
            shared = dict(
                tref=self.tref, sftfilepattern=self.sftfilepattern,
                minCoverFreq=self.minCoverFreq,
                maxCoverFreq=self.maxCoverFreq,
                detectors=self.detectors,
                minStartTime=self.minStartTime,
                maxStartTime=self.maxStartTime,
                BSGL=self.BSGL, injectSources=self.injectSources)

            if self.nsegs == 1:
                self.search = ComputeFstat(
                    SSBprec=self.SSBprec, assumeSqrtSX=self.assumeSqrtSX,
                    **shared)
                self.search.get_det_stat = self.search.get_fullycoherent_twoF
                return

            self.search = SemiCoherentSearch(
                label=self.label, outdir=self.outdir, nsegs=self.nsegs,
                **shared)

            def drop_time_bounds(*vals):
                # Grid rows start with two time columns that are not
                # forwarded to get_semicoherent_twoF
                return self.search.get_semicoherent_twoF(*vals[2:])

            self.search.get_det_stat = drop_time_bounds
    
        def get_array_from_tuple(self, x):
            """ Turn a grid specification into the array of values to search

            A length-3 input is read as (start, stop, step) and expanded with
            `np.arange` (so `stop` itself is excluded), unless
            `self.input_arrays` is set. Anything else is returned as an array
            of the given values.
            """
            is_grid_spec = len(x) == 3 and self.input_arrays is False
            if is_grid_spec:
                start, stop, step = x
                return np.arange(start, stop, step)
            if len(x) != 1:
                logging.info('Using tuple as is')
            return np.array(x)
    
        def get_input_data_array(self):
            """ Build the per-parameter coordinate arrays and the full grid

            Sets `self.coord_arrays` (one array per entry of
            `self.search_labels`) and `self.total_iterations`. Unless
            `args.clean` is set, also materialises every grid point into
            `self.input_data`.
            """
            logging.info("Generating input data array")
            self.coord_arrays = [
                self.get_array_from_tuple(np.atleast_1d(getattr(self, name)))
                for name in self.search_labels]
            self.total_iterations = np.prod(
                [len(arr) for arr in self.coord_arrays])

            if args.clean is False:
                grid_points = list(itertools.product(*self.coord_arrays))
                self.input_data = np.array(grid_points)
    
        def check_old_data_is_okay_to_use(self):
            """ Decide whether cached results in `self.out_file` can be reused

            The cache is rejected when `args.clean` is set, when the output
            file does not exist, when the output file is older than the oldest
            matching SFT file, or when the grid coordinates stored in the file
            differ from `self.input_data`.

            Returns
            -------
            data: np.ndarray or False
                The cached 2D data array if it is usable, otherwise False.
            """
            if args.clean:
                return False
            if os.path.isfile(self.out_file) is False:
                logging.info(
                    'No old data found in "{:s}", continuing with grid search'
                    .format(self.out_file))
                return False
            if self.sftfilepattern is not None:
                oldest_sft = min([os.path.getmtime(f) for f in
                                  self._get_list_of_matching_sfts()])
                if os.path.getmtime(self.out_file) < oldest_sft:
                    logging.info('Search output data outdates sft files,'
                                 + ' continuing with grid search')
                    return False

            data = np.atleast_2d(np.genfromtxt(self.out_file, delimiter=' '))
            ncoords = len(self.coord_arrays)
            # np.array_equal returns False for differently-shaped arrays; a
            # plain `==` raises on recent NumPy when the cached grid and the
            # requested grid have incompatible shapes.
            if np.array_equal(data[:, :ncoords],
                              self.input_data[:, :ncoords]):
                logging.info(
                    'Old data found in "{:s}" with matching input, no search '
                    'performed'.format(self.out_file))
                return data

            logging.info(
                'Old data found in "{:s}", input differs, continuing with '
                'grid search'.format(self.out_file))
            return False
    
        def run(self, return_data=False):
            """ Evaluate the detection statistic at every grid point

            Cached results are reused when `check_old_data_is_okay_to_use`
            accepts them (never the case when `args.clean` is set, where grid
            points are generated lazily instead of from `self.input_data`).

            Parameters
            ----------
            return_data: bool
                If True, return the results array. A freshly computed array
                is then neither saved to disk nor stored in `self.data`.

            Returns
            -------
            data: np.ndarray or None
                The results (grid coordinates followed by the detection
                statistic in each row) if `return_data` is True, else None.
            """
            self.get_input_data_array()

            if args.clean:
                iterable = itertools.product(*self.coord_arrays)
            else:
                old_data = self.check_old_data_is_okay_to_use()
                iterable = self.input_data

                if old_data is not False:
                    self.data = old_data
                    # Honour return_data for cached results as well
                    return old_data if return_data else None

            if hasattr(self, 'search') is False:
                self.inititate_search_object()

            data = []
            for vals in tqdm(iterable,
                             total=getattr(self, 'total_iterations', None)):
                detstat = self.search.get_det_stat(*vals)
                data.append(list(vals) + [detstat])

            # The `np.float` alias was removed from NumPy; builtin float is
            # the same dtype.
            data = np.array(data, dtype=float)
            if return_data:
                return data
            self.save_array_to_disk(data)
            self.data = data
    
        def get_header(self):
            """ Return the two-line header written above the saved data

            The first line records date, user and hostname separated by ';',
            the second lists `self.keys` separated by spaces.
            """
            provenance = [
                ('date', str(datetime.datetime.now())),
                ('user', getpass.getuser()),
                ('hostname', socket.gethostname()),
            ]
            first_line = ';'.join(
                '{}:{}'.format(tag, value) for tag, value in provenance)
            return first_line + '\n' + ' '.join(self.keys)
    
        def save_array_to_disk(self, data):
            """ Write `data` as space-delimited text to `self.out_file`

            The header from `get_header` is passed to `np.savetxt`.
            """
            logging.info('Saving data to {}'.format(self.out_file))
            np.savetxt(self.out_file, data, delimiter=' ',
                       header=self.get_header())
    
        def convert_F0_to_mismatch(self, F0, F0hat, Tseg):
            """ Return a mismatch axis matching an evenly spaced F0 grid

            The spacing is taken from the first two entries of `F0` and
            converted with (pi * Tseg * dF0)**2 / 12; the returned values are
            centred on zero. `F0hat` is accepted but not used.
            """
            step = F0[1] - F0[0]
            mismatch_step = (np.pi*Tseg*step)**2 / 12.
            half_span = len(F0)*mismatch_step/2.
            return np.arange(-half_span, half_span, mismatch_step)
    
        def convert_F1_to_mismatch(self, F1, F1hat, Tseg):
            """ Return a mismatch axis matching an evenly spaced F1 grid

            The spacing is taken from the first two entries of `F1` and
            converted with (pi * Tseg**2 * dF1)**2 / 720; the returned values
            are centred on zero. `F1hat` is accepted but not used.
            """
            step = F1[1] - F1[0]
            mismatch_step = (np.pi*Tseg**2*step)**2 / 720.
            half_span = len(F1)*mismatch_step/2.
            return np.arange(-half_span, half_span, mismatch_step)
    
        def add_mismatch_to_ax(self, ax, x, y, xkey, ykey, xhat, yhat, Tseg):
            """ Add twin axes to `ax` showing the metric mismatch

            A twin x-axis and a twin y-axis are always created (both pushed
            behind `ax` via zorder). Limits are only set on the twin x-axis
            when `xkey` is 'F0' and on the twin y-axis when `ykey` is 'F1'.
            """
            top_axis = ax.twiny()
            top_axis.zorder = -10
            right_axis = ax.twinx()
            right_axis.zorder = -10

            if xkey == 'F0':
                mismatch = self.convert_F0_to_mismatch(x, xhat, Tseg)
                top_axis.set_xlim(mismatch[0], mismatch[-1])

            if ykey == 'F1':
                mismatch = self.convert_F1_to_mismatch(y, yhat, Tseg)
                right_axis.set_ylim(mismatch[0], mismatch[-1])
    
        def plot_1D(self, xkey, ax=None, x0=None, xrescale=1, savefig=True,
                    xlabel=None, ylabel=r'$\widetilde{2\mathcal{F}}$'):
            """ Plot the last column of `self.data` against one parameter

            Parameters
            ----------
            xkey: str
                Entry of `self.keys` selecting the column used for the x-axis
            ax: matplotlib axes, optional
                Axes to draw on; a new figure is created if None
            x0: float, optional
                Reference value subtracted from the x values (0 is a valid
                reference); the axis label then includes `tex_labels0[xkey]`
            xrescale: float
                Factor applied to the (shifted) x values
            savefig: bool
                If True save to '<outdir>/<label>_1D.png' and return None,
                otherwise return the axes
            xlabel, ylabel: str
                Axis labels; `xlabel` overrides the automatic label

            Returns
            -------
            ax: matplotlib axes or None
            """
            fig = None
            if ax is None:
                fig, ax = plt.subplots()

            shifted = x0 is not None
            xidx = self.keys.index(xkey)
            x = np.unique(self.data[:, xidx])
            if shifted:
                x = x - x0
            x = x * xrescale
            z = self.data[:, -1]
            ax.plot(x, z)

            if xlabel:
                ax.set_xlabel(xlabel)
            elif shifted:
                ax.set_xlabel(self.tex_labels[xkey]+self.tex_labels0[xkey])
            else:
                ax.set_xlabel(self.tex_labels[xkey])
            ax.set_ylabel(ylabel)

            if not savefig:
                return ax
            if fig is None:
                # Caller supplied the axes: save the figure they belong to
                fig = ax.get_figure()
            fig.tight_layout()
            fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label))
    
        def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None,
                    add_mismatch=None, xN=None, yN=None, flat_keys=[],
                    rel_flat_idxs=[], flatten_method=np.max, title=None,
                    predicted_twoF=None, cm=None, cbarkwargs={}, x0=None, y0=None,
                    colorbar=False, xrescale=1, yrescale=1):
            """ Plots a 2D grid of 2F values

            Parameters
            ----------
            xkey, ykey: str
                Entries of `self.keys` selecting the x and y columns
            ax: matplotlib axes, optional
                Axes to draw on; a new figure is created if None
            save: bool
                If True save to '<outdir>/<label>_2D.png' and return None,
                otherwise return the axes
            vmin, vmax, cm:
                Passed to `pcolormesh` (`cm` as `cmap`); `cm` defaults to
                viridis, reversed when `predicted_twoF` is given
            add_mismatch: tuple (xhat, yhat, Tseg)
                If not None, add a secondary axis with the metric mismatch from
                the point xhat, yhat with duration Tseg
            xN, yN: int
                If given, maximum number of major ticks on each axis
            flat_keys: list of str
                Additional keys whose unique values define trailing
                dimensions of the reshaped data
            rel_flat_idxs: list of int
                Axes of the reshaped data reduced with `flatten_method`
            flatten_method: np.max
                Function to use in flattening the flat_keys
            predicted_twoF: float
                If given, plot (predicted_twoF - 2F) / (predicted_twoF + 4)
            x0, y0: float
                Reference values subtracted from the axes (0 is a valid
                reference); labels then include `tex_labels0`
            colorbar, cbarkwargs:
                Whether to add a colorbar and keyword arguments for it
            xrescale, yrescale: float
                Factors applied to the plotted coordinates and limits

            The mutable defaults are never modified in this method.
            """
            fig = None
            if ax is None:
                fig, ax = plt.subplots()
            xidx = self.keys.index(xkey)
            yidx = self.keys.index(ykey)
            flat_idxs = [self.keys.index(k) for k in flat_keys]

            x_shifted = x0 is not None
            y_shifted = y0 is not None

            x = np.unique(self.data[:, xidx])
            if x_shifted:
                x = x-x0
            y = np.unique(self.data[:, yidx])
            if y_shifted:
                y = y-y0
            flat_vals = [np.unique(self.data[:, j]) for j in flat_idxs]
            z = self.data[:, -1]

            Y, X = np.meshgrid(y, x)
            shape = [len(x), len(y)] + [len(v) for v in flat_vals]
            Z = z.reshape(shape)

            if len(rel_flat_idxs) > 0:
                Z = flatten_method(Z, axis=tuple(rel_flat_idxs))

            if predicted_twoF:
                Z = (predicted_twoF - Z) / (predicted_twoF + 4)
                if cm is None:
                    cm = plt.cm.viridis_r
            else:
                if cm is None:
                    cm = plt.cm.viridis

            pax = ax.pcolormesh(
                X*xrescale, Y*yrescale, Z, cmap=cm, vmin=vmin, vmax=vmax)
            if colorbar:
                cb = plt.colorbar(pax, ax=ax, **cbarkwargs)
                cb.set_label(r'$2\mathcal{F}$')

            if add_mismatch:
                self.add_mismatch_to_ax(ax, x, y, xkey, ykey, *add_mismatch)

            ax.set_xlim(x[0]*xrescale, x[-1]*xrescale)
            ax.set_ylim(y[0]*yrescale, y[-1]*yrescale)
            if x_shifted:
                ax.set_xlabel(self.tex_labels[xkey]+self.tex_labels0[xkey])
            else:
                ax.set_xlabel(self.tex_labels[xkey])
            if y_shifted:
                ax.set_ylabel(self.tex_labels[ykey]+self.tex_labels0[ykey])
            else:
                ax.set_ylabel(self.tex_labels[ykey])

            if title:
                ax.set_title(title)

            if xN:
                ax.xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(xN))
            if yN:
                ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(yN))

            if not save:
                return ax
            if fig is None:
                # Caller supplied the axes: save the figure they belong to
                fig = ax.get_figure()
            fig.tight_layout()
            fig.savefig('{}/{}_2D.png'.format(self.outdir, self.label))
    
        def get_max_twoF(self):
            """ Get the maximum twoF over the grid

            The row of `self.data` with the largest last-column value is
            selected and its first eight entries are labelled.

            Returns
            -------
            d: dict
                Dictionary containing, 'minStartTime', 'maxStartTime', 'F0',
                'F1', 'F2', 'Alpha', 'Delta' and 'twoF' of maximum

            """
            column_names = ('minStartTime', 'maxStartTime', 'F0', 'F1', 'F2',
                            'Alpha', 'Delta', 'twoF')
            best_row = self.data[np.argmax(self.data[:, -1]), :]
            return OrderedDict(
                (name, best_row[col]) for col, name in enumerate(column_names))
    
        def print_max_twoF(self):
            """ Print the parameters of the grid point with the largest twoF

            Prints a heading with `self.label` followed by one indented
            'key=value' line per entry returned by `get_max_twoF`.
            """
            d = self.get_max_twoF()
            print('Max twoF values for {}:'.format(self.label))
            # dict.items() exists on both Python 2 and 3; iteritems() is
            # Python 2 only.
            for k, v in d.items():
                print('  {}={}'.format(k, v))
    
        def set_out_file(self, extra_label=None):
            """ Set `self.out_file` from outdir, label, detectors and class name

            The file name is '<outdir>/<label>_<dets>_<ClassName>.txt', with
            '_<extra_label>' inserted before the extension when given. `dets`
            is `self.detectors` without commas, or 'NA' if it is unset.
            """
            dets = self.detectors.replace(',', '') if self.detectors else 'NA'
            template = '{}/{}_{}_{}'
            fields = [self.outdir, self.label, dets, type(self).__name__]
            if extra_label:
                template += '_{}'
                fields.append(extra_label)
            self.out_file = (template + '.txt').format(*fields)
    
    
    class TransientGridSearch(GridSearch):
        """ Gridded transient-continous search using ComputeFstat """
    
        @helper_functions.initializer
        def __init__(self, label, outdir, sftfilepattern, F0s, F1s, F2s, Alphas,
                     Deltas, tref=None, minStartTime=None, maxStartTime=None,
                     BSGL=False, minCoverFreq=None, maxCoverFreq=None,
                     detectors=None, SSBprec=None, injectSources=None,
                     input_arrays=False, assumeSqrtSX=None,
                     transientWindowType=None, t0Band=None, tauBand=None,
                     dt0=None, dtau=None,
                     outputTransientFstatMap=False,
                     outputAtoms=False,
                     tCWFstatMapVersion='lal', cudaDeviceName=None):
            """ Set up a transient grid search (always a single segment)

            The body reads the arguments back from `self`, so the
            `helper_functions.initializer` decorator is assumed to have stored
            every argument as an attribute of the same name beforehand.

            Parameters
            ----------
            label, outdir: str
                A label and directory to read/write data from/to
            sftfilepattern: str
                Pattern to match SFTs using wildcards (*?) and ranges [0-9];
                multiple patterns can be given separated by colons.
            F0s, F1s, F2s, Alphas, Deltas: tuple
                Length 3 tuple describing the grid for each parameter, e.g
                [F0min, F0max, dF0]; for a fixed value simply give [F0].
                If input_arrays == True, these are instead the values to
                search at.
            tref, minStartTime, maxStartTime: int
                GPS seconds of the reference time, start time and end time
            input_arrays: bool
                if true, use the F0s, F1s, etc as is
            transientWindowType: str
                If 'rect' or 'exp', compute atoms so that a transient (t0,tau)
                map can later be computed.  ('none' instead of None explicitly
                calls the transient-window function, but with the full range,
                for debugging). Currently only supported for nsegs=1.
            t0Band, tauBand: int
                if >0, search t0 in (minStartTime,minStartTime+t0Band)
                       and tau in (2*Tsft,2*Tsft+tauBand).
                if =0, only compute CW Fstat with t0=minStartTime,
                       tau=maxStartTime-minStartTime.
            dt0, dtau: int
                grid resolutions in transient start-time and duration,
                both default to Tsft
            outputTransientFstatMap: bool
                if true, write output files for (t0,tau) Fstat maps
                (one file for each doppler grid point!)
            outputAtoms: bool
                if true, `run` asks the search object to write atoms to file
                after each grid point
            tCWFstatMapVersion: str
                Choose between standard 'lal' implementation,
                'pycuda' for gpu, and some others for devel/debug.
            cudaDeviceName: str
                GPU name to be matched against drv.Device output.

            For all other parameters, see `pyfstat.ComputeFStat` for details
            """

            self.nsegs = 1
            if not os.path.isdir(outdir):
                os.mkdir(outdir)
            self.set_out_file()

            # The two leading placeholders stand for the start/end-time columns
            self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta']
            self.search_keys = [key + 's' for key in self.keys[2:]]
            # Normalise scalars/lists to 1D arrays so len() always works
            for name in self.search_keys:
                values = getattr(self, name)
                setattr(self, name, np.atleast_1d(values))
    
        def inititate_search_object(self):
            """ Create the `ComputeFstat` object used for the transient search

            Stores it in `self.search` and exposes its fully-coherent twoF
            method as `self.search.get_det_stat`.
            """
            logging.info('Setting up search object')
            search_kwargs = dict(
                tref=self.tref, sftfilepattern=self.sftfilepattern,
                minCoverFreq=self.minCoverFreq,
                maxCoverFreq=self.maxCoverFreq,
                detectors=self.detectors,
                transientWindowType=self.transientWindowType,
                t0Band=self.t0Band, tauBand=self.tauBand,
                dt0=self.dt0, dtau=self.dtau,
                minStartTime=self.minStartTime,
                maxStartTime=self.maxStartTime,
                BSGL=self.BSGL, SSBprec=self.SSBprec,
                injectSources=self.injectSources,
                assumeSqrtSX=self.assumeSqrtSX,
                tCWFstatMapVersion=self.tCWFstatMapVersion,
                cudaDeviceName=self.cudaDeviceName)
            search = ComputeFstat(**search_kwargs)
            self.search = search
            search.get_det_stat = search.get_fullycoherent_twoF
    
        def run(self, return_data=False):
            """ Evaluate the detection statistic at every grid point

            When `self.transientWindowType` is set, each result row is
            extended with the (t0, tau) of the maximum of the transient
            F-stat map, and the map itself is optionally written to a
            per-grid-point file.

            Parameters
            ----------
            return_data: bool
                If True, return the results array. A freshly computed array
                is then neither saved to disk nor stored in `self.data`.

            Returns
            -------
            data: np.ndarray or None
                The results if `return_data` is True, else None.
            """
            self.get_input_data_array()
            old_data = self.check_old_data_is_okay_to_use()
            if old_data is not False:
                self.data = old_data
                # Honour return_data for cached results as well
                return old_data if return_data else None

            if hasattr(self, 'search') is False:
                self.inititate_search_object()

            # get_input_data_array() does not build self.input_data when
            # args.clean is set, so generate grid points lazily in that case
            # (same approach as GridSearch.run).
            if args.clean:
                iterable = itertools.product(*self.coord_arrays)
            else:
                iterable = self.input_data

            data = []
            if self.outputTransientFstatMap:
                tCWfilebase = os.path.splitext(self.out_file)[0] + '_tCW_'
                logging.info('Will save per-Doppler Fstatmap'
                             ' results to {}*.dat'.format(tCWfilebase))
            for vals in tqdm(iterable,
                             total=getattr(self, 'total_iterations', None)):
                detstat = self.search.get_det_stat(*vals)
                windowRange = getattr(self.search, 'windowRange', None)
                FstatMap = getattr(self.search, 'FstatMap', None)
                thisCand = list(vals) + [detstat]
                if getattr(self, 'transientWindowType', None):
                    # The 'lal' map wraps its array in a `.data` attribute
                    if self.tCWFstatMapVersion == 'lal':
                        F_mn = FstatMap.F_mn.data
                    else:
                        F_mn = FstatMap.F_mn
                    if self.outputTransientFstatMap:
                        # per-Doppler filename convention:
                        # freq alpha delta f1dot f2dot
                        tCWfile = (tCWfilebase
                                   + '%.16f_%.16f_%.16f_%.16g_%.16g.dat' %
                                   (vals[2], vals[5], vals[6],
                                    vals[3], vals[4]))
                        if self.tCWFstatMapVersion == 'lal':
                            fo = lal.FileOpen(tCWfile, 'w')
                            lalpulsar.write_transientFstatMap_to_fp(
                                fo, FstatMap, windowRange, None)
                            # instead of lal.FileClose(),
                            # which is not SWIG-exported:
                            del fo
                        else:
                            self.write_F_mn(tCWfile, F_mn, windowRange)
                    maxidx = np.unravel_index(F_mn.argmax(), F_mn.shape)
                    thisCand += [windowRange.t0+maxidx[0]*windowRange.dt0,
                                 windowRange.tau+maxidx[1]*windowRange.dtau]
                data.append(thisCand)
                if self.outputAtoms:
                    self.search.write_atoms_to_file(
                        os.path.splitext(self.out_file)[0])

            # The `np.float` alias was removed from NumPy; builtin float is
            # the same dtype.
            data = np.array(data, dtype=float)
            if return_data:
                return data
            self.save_array_to_disk(data)
            self.data = data
    
        def write_F_mn(self, tCWfile, F_mn, windowRange):
            """ Write a (t0, tau) map to a text file as 't0 tau 2F' rows

            Row index m of `F_mn` maps to t0 = windowRange.t0 + m*dt0 and
            column index n to tau = windowRange.tau + n*dtau; the stored
            values are doubled on output.
            """
            row_format = '  %10d %10d %- 11.8g\n'
            with open(tCWfile, 'w') as out:
                out.write('# t0 [s]     tau [s]     2F\n')
                for row_idx, row in enumerate(F_mn):
                    t0 = windowRange.t0 + row_idx * windowRange.dt0
                    out.writelines(
                        row_format % (
                            t0,
                            windowRange.tau + col_idx * windowRange.dtau,
                            2.0*value)
                        for col_idx, value in enumerate(row))
    
        def __del__(self):
            """ Forward finalisation to the wrapped search object, if any """
            # `self.search` is only created by inititate_search_object(), which
            # run() calls on demand, so it may be absent when this runs.
            if hasattr(self,'search'):
                # NOTE(review): this calls the search object's finaliser
                # explicitly; presumably it releases resources held by the
                # search object -- confirm against ComputeFstat.
                self.search.__del__()
    
    
    class SliceGridSearch(GridSearch):
        """ Slice gridded search using ComputeFstat """
        @helper_functions.initializer
        def __init__(self, label, outdir, sftfilepattern, F0s, F1s, F2s, Alphas,
                     Deltas, tref=None, minStartTime=None, maxStartTime=None,
                     nsegs=1, BSGL=False, minCoverFreq=None, maxCoverFreq=None,
                     detectors=None, SSBprec=None, injectSources=None,
                     input_arrays=False, assumeSqrtSX=None, Lambda0=None):
            """ Set up 1D/2D slice searches through the point `Lambda0`

            The body reads the arguments back from `self`, so the
            `helper_functions.initializer` decorator is assumed to have stored
            every argument as an attribute of the same name beforehand.

            Parameters
            ----------
            label, outdir: str
                A label and directory to read/write data from/to
            sftfilepattern: str
                Pattern to match SFTs using wildcards (*?) and ranges [0-9];
                multiple patterns can be given separated by colons.
            F0s, F1s, F2s, Alphas, Deltas: tuple
                Length 3 tuple describing the grid for each parameter, e.g
                [F0min, F0max, dF0]; for a fixed value simply give [F0].
                If input_arrays == True, these are instead the values to
                search at.
            tref, minStartTime, maxStartTime: int
                GPS seconds of the reference time, start time and end time
            input_arrays: bool
                if true, use the F0s, F1s, etc as is
            Lambda0: sequence of length 4
                Reference values for F0, F1, Alpha and Delta (in that order)

            For all other parameters, see `pyfstat.ComputeFStat` for details

            Raises
            ------
            ValueError
                If `Lambda0` is missing or does not have one entry per
                search key.
            """

            if not os.path.isdir(outdir):
                os.mkdir(outdir)
            self.set_out_file()
            self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta']

            # Only these four parameters are sliced; F2 is not among them
            self.thetas = [F0s, F1s, Alphas, Deltas]
            self.ndim = 4
            self.search_keys = ['F0', 'F1', 'Alpha', 'Delta']

            if self.Lambda0 is None:
                raise ValueError('Lambda0 undefined')
            n_expected = len(self.search_keys)
            if len(self.Lambda0) != n_expected:
                raise ValueError(
                    'Lambda0 must be of length {}'.format(n_expected))
            self.Lambda0 = np.array(Lambda0)
    
        def run(self, factor=2, max_n_ticks=4, whspace=0.07, save=True,
                **kwargs):
            """ Run all 1D and 2D slice searches and draw them in a corner plot

            The diagonal panels hold 1D slices through `self.Lambda0`, the
            lower-triangle panels hold 2D slices; upper-triangle panels are
            blanked. Each slice is a separate `GridSearch` run with its own
            label and output file.

            Parameters
            ----------
            factor: float
                Scale used to size each panel and the figure margins
            max_n_ticks: int
                Maximum number of major ticks per axis in the 2D panels
            whspace: float
                Spacing between panels (used for both wspace and hspace)
            save: bool
                If True save to '<outdir>/<label>_slice_projection.png' and
                return None, otherwise return the figure and axes
            **kwargs:
                Passed on to `GridSearch.plot_2D`

            Returns
            -------
            fig, axes: tuple or None
            """
            lbdim = 0.5 * factor   # size of left/bottom margin
            trdim = 0.4 * factor   # size of top/right margin
            plotdim = factor * self.ndim + factor * (self.ndim - 1.) * whspace
            dim = lbdim + plotdim + trdim

            fig, axes = plt.subplots(self.ndim, self.ndim, figsize=(dim, dim))

            # Format the figure.
            lb = lbdim / dim
            tr = (lbdim + plotdim) / dim
            fig.subplots_adjust(left=lb, bottom=lb, right=tr, top=tr,
                                wspace=whspace, hspace=whspace)

            # One search object is reused for every slice: it starts with all
            # parameters fixed at Lambda0 and individual parameters are
            # temporarily swapped for their full grids below.
            search = GridSearch(
                self.label, self.outdir, self.sftfilepattern,
                F0s=self.Lambda0[0], F1s=self.Lambda0[1], F2s=self.F2s[0],
                Alphas=self.Lambda0[2], Deltas=self.Lambda0[3], tref=self.tref,
                minStartTime=self.minStartTime, maxStartTime=self.maxStartTime)

            for i, ikey in enumerate(self.search_keys):
                # 1D slice along parameter i (diagonal panel)
                setattr(search, ikey+'s', self.thetas[i])
                search.label = '{}_{}'.format(self.label, ikey)
                search.set_out_file()
                search.run()
                axes[i, i] = search.plot_1D(ikey, ax=axes[i, i], savefig=False,
                                            x0=self.Lambda0[i]
                                            )
                # Restore parameter i to its fixed reference value
                setattr(search, ikey+'s', [self.Lambda0[i]])
                axes[i, i].yaxis.tick_right()
                axes[i, i].yaxis.set_label_position("right")
                axes[i, i].set_xlabel('')

                for j, jkey in enumerate(self.search_keys):
                    ax = axes[i, j]

                    if j > i:
                        # Upper triangle: leave empty
                        ax.set_frame_on(False)
                        ax.set_xticks([])
                        ax.set_yticks([])
                        continue

                    # NOTE(review): `get_shared_*_axes().join` is deprecated
                    # in recent matplotlib releases -- confirm the supported
                    # matplotlib version before upgrading.
                    ax.get_shared_x_axes().join(axes[self.ndim-1, j], ax)
                    if i < self.ndim - 1:
                        ax.set_xticklabels([])
                    if j < i:
                        ax.get_shared_y_axes().join(axes[i, i-1], ax)
                        if j > 0:
                            ax.set_yticklabels([])
                    if j == i:
                        continue

                    ax.xaxis.set_major_locator(
                        matplotlib.ticker.MaxNLocator(max_n_ticks, prune="upper"))
                    ax.yaxis.set_major_locator(
                        matplotlib.ticker.MaxNLocator(max_n_ticks, prune="upper"))

                    # 2D slice over parameters (j, i) (lower-triangle panel)
                    setattr(search, ikey+'s', self.thetas[i])
                    setattr(search, jkey+'s', self.thetas[j])
                    search.label = '{}_{}'.format(self.label, ikey+jkey)
                    search.set_out_file()
                    search.run()
                    ax = search.plot_2D(jkey, ikey, ax=ax, save=False,
                                        y0=self.Lambda0[i], x0=self.Lambda0[j],
                                        **kwargs)
                    # Restore both parameters to their reference values
                    setattr(search, ikey+'s', [self.Lambda0[i]])
                    setattr(search, jkey+'s', [self.Lambda0[j]])

                    ax.grid(lw=0.2, ls='--', zorder=10)
                    ax.set_xlabel('')
                    ax.set_ylabel('')

            # Only the bottom row and left column carry parameter labels
            for i, ikey in enumerate(self.search_keys):
                axes[-1, i].set_xlabel(
                    self.tex_labels[ikey]+self.tex_labels0[ikey])
                if i > 0:
                    axes[i, 0].set_ylabel(
                        self.tex_labels[ikey]+self.tex_labels0[ikey])
                # Raw string: '\m' is not a valid escape sequence
                axes[i, i].set_ylabel(r"$2\mathcal{F}$")

            if save:
                fig.savefig(
                    '{}/{}_slice_projection.png'.format(self.outdir, self.label))
            else:
                return fig, axes
    
    
    class GridUniformPriorSearch():
        """ Grid search over F0 and F1 spanning the bounds of a uniform prior

        Thin wrapper around `GridSearch`: the F0 and F1 grids run from the
        'lower' to the 'upper' bound given in `theta_prior`, split into NF0
        and NF1 steps; Alpha and Delta are taken as fixed values.
        """

        @helper_functions.initializer
        def __init__(self, theta_prior, NF0, NF1, label, outdir, sftfilepattern,
                     tref, minStartTime, maxStartTime, minCoverFreq=None,
                     maxCoverFreq=None, BSGL=False, detectors=None, nsegs=1,
                     SSBprec=None, injectSources=None):
            """
            Parameters
            ----------
            theta_prior: dict
                Must provide 'F0' and 'F1' entries with 'lower' and 'upper'
                bounds, and fixed 'Alpha' and 'Delta' values. An optional
                'F2' entry is used as a fixed value (default 0).
            NF0, NF1: int
                Number of grid steps in F0 and F1
            label, outdir: str
                A label and directory to read/write data from/to
            sftfilepattern: str
                Pattern to match SFTs
            tref, minStartTime, maxStartTime: int
                GPS seconds of the reference time, start time and end time

            All remaining parameters are passed through to `GridSearch`.
            """
            F0_lower = theta_prior['F0']['lower']
            F0_upper = theta_prior['F0']['upper']
            F1_lower = theta_prior['F1']['lower']
            F1_upper = theta_prior['F1']['upper']
            F0s = [F0_lower, F0_upper, (F0_upper - F0_lower)/NF0]
            F1s = [F1_lower, F1_upper, (F1_upper - F1_lower)/NF1]
            # GridSearch requires F2s; omitting it raised a TypeError.
            # NOTE(review): assumes theta_prior['F2'], when present, is a
            # fixed scalar like Alpha/Delta -- confirm against callers.
            F2 = theta_prior['F2'] if 'F2' in theta_prior else 0
            self.search = GridSearch(
                label, outdir, sftfilepattern, F0s=F0s, F1s=F1s, F2s=[F2],
                tref=tref,
                Alphas=[theta_prior['Alpha']], Deltas=[theta_prior['Delta']],
                minStartTime=minStartTime, maxStartTime=maxStartTime, BSGL=BSGL,
                detectors=detectors, minCoverFreq=minCoverFreq,
                injectSources=injectSources, maxCoverFreq=maxCoverFreq,
                nsegs=nsegs, SSBprec=SSBprec)

        def run(self):
            """ Run the wrapped grid search """
            self.search.run()

        def get_2D_plot(self, **kwargs):
            """ Return the wrapped search's F0-F1 plot (kwargs to plot_2D) """
            return self.search.plot_2D('F0', 'F1', **kwargs)
    
    
    class GridGlitchSearch(GridSearch):
        """ Grid search using the SemiCoherentGlitchSearch """
        # Attribute names read (in this order) by get_input_data_array() to
        # build the grid; each grid row is passed positionally to
        # `self.search.get_det_stat` by run(). Overrides the GridSearch list,
        # which starts with minStartTime/maxStartTime.
        search_labels = ['F0s', 'F1s', 'F2s', 'Alphas', 'Deltas', 'delta_F0s',
                         'delta_F1s', 'tglitchs']
    
        @helper_functions.initializer
        def __init__(self, label, outdir='data', sftfilepattern=None, F0s=[0],
                     F1s=[0], F2s=[0], delta_F0s=[0], delta_F1s=[0], tglitchs=None,
                     Alphas=[0], Deltas=[0], tref=None, minStartTime=None,
                     maxStartTime=None, minCoverFreq=None, maxCoverFreq=None,
                     detectors=None):
            """
            Run a single-glitch grid search

            The body reads some arguments back from `self`, so the
            `helper_functions.initializer` decorator is assumed to have stored
            every argument as an attribute of the same name beforehand.

            Parameters
            ----------
            label, outdir: str
                A label and directory to read/write data from/to
            sftfilepattern: str
                Pattern to match SFTs using wildcards (*?) and ranges [0-9];
                multiple patterns can be given separated by colons.
            F0s, F1s, F2s, delta_F0s, delta_F1s, tglitchs, Alphas, Deltas: tuple
                Length 3 tuple describing the grid for each parameter, e.g
                [F0min, F0max, dF0], for a fixed value simply give [F0]. Note
                that tglitchs is referenced to zero at minStartTime.
            tref, minStartTime, maxStartTime: int
                GPS seconds of the reference time, start time and end time

            For all other parameters, see pyfstat.ComputeFStat.

            Raises
            ------
            ValueError
                If `tglitchs` is not given.
            """

            self.BSGL = False
            self.input_arrays = False
            if tglitchs is None:
                raise ValueError('You must specify `tglitchs`')

            glitch_search = SemiCoherentGlitchSearch(
                label=label, outdir=outdir,
                sftfilepattern=self.sftfilepattern, tref=tref,
                minStartTime=minStartTime, maxStartTime=maxStartTime,
                minCoverFreq=minCoverFreq, maxCoverFreq=maxCoverFreq,
                BSGL=self.BSGL)
            self.search = glitch_search
            glitch_search.get_det_stat = (
                glitch_search.get_semicoherent_nglitch_twoF)

            if not os.path.isdir(outdir):
                os.mkdir(outdir)
            self.set_out_file()
            # Column names of the result rows, matching `search_labels`
            self.keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta', 'delta_F0',
                         'delta_F1', 'tglitch']
    
    
    class SlidingWindow(GridSearch):
        """ Evaluate a fixed-parameter search over a sliding time window """
        @helper_functions.initializer
        def __init__(self, label, outdir, sftfilepattern, F0, F1, F2,
                     Alpha, Delta, tref, minStartTime=None,
                     maxStartTime=None, window_size=10*86400, window_delta=86400,
                     BSGL=False, minCoverFreq=None, maxCoverFreq=None,
                     detectors=None, SSBprec=None, injectSources=None):
            """
            Parameters
            ----------
            label, outdir: str
                A label and directory to read/write data from/to; `outdir` is
                created (including missing parent directories) if absent.
            sftfilepattern: str
                Pattern to match SFTs using wildcards (*?) and ranges [0-9];
                multiple patterns can be given separated by colons.
            F0, F1, F2, Alpha, Delta: float
                Fixed values to compute output over
            tref, minStartTime, maxStartTime: int
                GPS seconds of the reference time, start time and end time.
                Despite the `None` defaults, minStartTime and maxStartTime
                must be numbers: they are used in arithmetic below.
            window_size, window_delta: int
                Length of each window and step between successive window
                starts, in seconds (the plot divides both by 86400 and
                reports them as days).

            For all other parameters, see `pyfstat.ComputeFStat` for details
            """

            # makedirs (rather than mkdir) so that a nested `outdir` works
            # even when its parent directories do not exist yet.
            if not os.path.isdir(outdir):
                os.makedirs(outdir)
            self.set_out_file()
            self.nsegs = 1

            # Window start times: begin at minStartTime and keep stepping by
            # window_delta while the latest window still ends before
            # maxStartTime.
            self.tstarts = [self.minStartTime]
            while self.tstarts[-1] + self.window_size < self.maxStartTime:
                self.tstarts.append(self.tstarts[-1]+self.window_delta)
            self.tmids = np.array(self.tstarts) + .5 * self.window_size

        def inititate_search_object(self):
            """ Create the ComputeFstat instance stored as `self.search` """
            logging.info('Setting up search object')
            # NOTE(review): FrequencySlidingWindow in this file passes
            # `transientWindowType=...` where this passes `transient=True`;
            # confirm which keyword ComputeFstat accepts.
            self.search = ComputeFstat(
                tref=self.tref, sftfilepattern=self.sftfilepattern,
                minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq,
                detectors=self.detectors, transient=True,
                minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
                BSGL=self.BSGL, SSBprec=self.SSBprec,
                injectSources=self.injectSources)

        def check_old_data_is_okay_to_use(self, out_file):
            """ Load cached results from `out_file` if they match this setup

            The cache is accepted when it has as many rows as there are
            windows and its first mid-point equals the first current
            mid-point; `self.vals` and `self.errvals` are then set from it.

            Returns
            -------
            bool
                True if the cached data was loaded, False otherwise.
            """
            if not os.path.isfile(out_file):
                return False

            # ndmin=2 keeps a one-row file two-dimensional; without it the
            # columns unpack to scalars and the len() below raises TypeError.
            tmids, vals, errvals = np.loadtxt(out_file, ndmin=2).T
            if len(tmids) == len(self.tmids) and tmids[0] == self.tmids[0]:
                self.vals = vals
                self.errvals = errvals
                return True
            return False

        def run(self, key='h0', errkey='dh0'):
            """ Compute (or load from cache) `key` and `errkey` per window

            Parameters
            ----------
            key, errkey: str
                Entries to read from the output of
                `self.search.get_full_CFSv2_output` for each window; stored
                in `self.vals` and `self.errvals` respectively.
            """
            self.key = key
            self.errkey = errkey
            out_file = '{}/{}_{}-sliding-window.txt'.format(
                self.outdir, self.label, key)

            if self.check_old_data_is_okay_to_use(out_file) is not False:
                # self.vals / self.errvals were filled from the cache.
                return

            self.inititate_search_object()
            vals = []
            errvals = []
            for ts in self.tstarts:
                loudest = self.search.get_full_CFSv2_output(
                    ts, ts+self.window_size, self.F0, self.F1, self.F2,
                    self.Alpha, self.Delta, self.tref)
                vals.append(loudest[key])
                errvals.append(loudest[errkey])

            # One row per window: mid-point, value, error.
            np.savetxt(out_file, np.array([self.tmids, vals, errvals]).T)
            self.vals = np.array(vals)
            self.errvals = np.array(errvals)

        def plot_sliding_window(self, factor=1, fig=None, ax=None):
            """ Plot the values from `run` against the window mid-points

            Parameters
            ----------
            factor: float
                Multiplier applied to both the values and their errors.
            fig, ax: matplotlib figure and axes, optional
                If `ax` is None a new figure and axes are created.

            Returns
            -------
            ax or None
                When a figure is available (passed in, or created here) the
                plot is saved to `outdir` and nothing is returned; otherwise
                the axes are returned.
            """
            if ax is None:
                fig, ax = plt.subplots()
            days = (self.tmids-self.minStartTime) / 86400
            ax.errorbar(days, self.vals*factor, yerr=self.errvals*factor)
            ax.set_ylabel(self.key)
            ax.set_xlabel(
                r'Mid-point (days after $t_\mathrm{{start}}$={})'.format(
                    self.minStartTime))
            ax.set_title(
                'Sliding window of {} days in increments of {} days'
                .format(self.window_size/86400, self.window_delta/86400),
                )

            if fig:
                fig.savefig('{}/{}_{}-sliding-window.png'.format(
                    self.outdir, self.label, self.key))
            else:
                return ax
    
    
    class FrequencySlidingWindow(GridSearch):
        """ A sliding-window search over the Frequency """
        @helper_functions.initializer
        def __init__(self, label, outdir, sftfilepattern, F0s, F1, F2,
                     Alpha, Delta, tref, minStartTime=None,
                     maxStartTime=None, window_size=10*86400, window_delta=86400,
                     BSGL=False, minCoverFreq=None, maxCoverFreq=None,
                     detectors=None, SSBprec=None, injectSources=None):
            """
            Parameters
            ----------
            label, outdir: str
                A label and directory to read/write data from/to; `outdir` is
                created (including missing parent directories) if absent.
            sftfilepattern: str
                Pattern to match SFTs using wildcards (*?) and ranges [0-9];
                multiple patterns can be given separated by colons.
            F0s: array
                Frequency range
            F1, F2, Alpha, Delta: float
                Fixed values to compute twoF(F) over
            tref, minStartTime, maxStartTime: int
                GPS seconds of the reference time, start time and end time
            window_size, window_delta: int
                Length of each window and step between successive window
                starts, in seconds (the plot divides both by 86400 and
                reports them as days).

            For all other parameters, see `pyfstat.ComputeFStat` for details
            """

            # Settings that are fixed for this search rather than exposed as
            # constructor arguments.
            self.transientWindowType = 'rect'
            self.nsegs = 1
            self.t0Band = None
            self.tauBand = None

            # makedirs (rather than mkdir) so that a nested `outdir` works
            # even when its parent directories do not exist yet.
            if not os.path.isdir(outdir):
                os.makedirs(outdir)
            self.set_out_file()

            # Wrap the fixed values as one-element grids, the form that
            # get_input_data_array passes to get_array_from_tuple.
            self.F1s = [F1]
            self.F2s = [F2]
            self.Alphas = [Alpha]
            self.Deltas = [Delta]
            self.input_arrays = False
            self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta']

        def inititate_search_object(self):
            """ Create the ComputeFstat instance stored as `self.search` """
            logging.info('Setting up search object')
            self.search = ComputeFstat(
                tref=self.tref, sftfilepattern=self.sftfilepattern,
                minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq,
                detectors=self.detectors, transientWindowType=self.transientWindowType,
                minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
                BSGL=self.BSGL, SSBprec=self.SSBprec,
                injectSources=self.injectSources)
            self.search.get_det_stat = (
                self.search.get_fullycoherent_twoF)

        def get_input_data_array(self):
            """ Build `self.input_data`, one row per (window, grid point)

            Each row is (window start, window end, F0, F1, F2, Alpha, Delta);
            `self.coord_arrays` keeps the per-axis values, window starts
            first.
            """
            # Window start times: begin at minStartTime and keep stepping by
            # window_delta while the latest window still ends before
            # maxStartTime.
            tstarts = [self.minStartTime]
            while tstarts[-1] + self.window_size < self.maxStartTime:
                tstarts.append(tstarts[-1]+self.window_delta)

            coord_arrays = [tstarts]
            coord_arrays.extend(
                self.get_array_from_tuple(tup)
                for tup in (self.F0s, self.F1s, self.F2s,
                            self.Alphas, self.Deltas))

            # Cartesian product of all axes, then insert the window end time
            # (start + window_size) as the second column.
            input_data = np.array(list(itertools.product(*coord_arrays)))
            input_data = np.insert(
                input_data, 1, input_data[:, 0] + self.window_size, axis=1)

            self.coord_arrays = coord_arrays
            self.input_data = input_data

        def plot_sliding_window(self, F0=None, ax=None, savefig=True,
                                colorbar=True, timestamps=False,
                                F0rescale=1, **kwargs):
            """ Colour-mesh plot of `self.data` over window mid-point and F0

            Parameters
            ----------
            F0: float, optional
                If given (and non-zero), it is subtracted from the
                frequencies and quoted in the y-axis label.
            ax: matplotlib axes, optional
                Axes to draw on; created if None.
            savefig: bool
                If True, save the figure to `outdir`; otherwise return `ax`.
            colorbar: bool
                If True, add a colorbar for the plotted statistic.
            timestamps: bool
                If True, add a top x-axis showing the mid-point GPS times
                in units of 1e9 s.
            F0rescale: float
                Multiplier applied to the frequency axis.
            **kwargs:
                Passed on to `ax.pcolormesh`.
            """
            data = self.data
            if ax is None:
                ax = plt.subplot()
            tstarts = np.unique(data[:, 0])
            tends = np.unique(data[:, 1])
            frequencies = np.unique(data[:, 2])
            twoF = data[:, -1]
            tmids = (tstarts + tends) / 2.0
            dts = (tmids - self.minStartTime) / 86400.
            if F0:
                frequencies = frequencies - F0
                ax.set_ylabel('Frequency - $f_0$ [Hz] \n $f_0={:0.2f}$'.format(F0))
            else:
                ax.set_ylabel('Frequency [Hz]')
            # Rows of `data` are reshaped to (window, frequency), matching
            # the meshgrid layout below.
            twoF = twoF.reshape((len(tmids), len(frequencies)))
            Y, X = np.meshgrid(frequencies, dts)
            pax = ax.pcolormesh(X, Y*F0rescale, twoF, **kwargs)
            if colorbar:
                cb = plt.colorbar(pax, ax=ax)
                # Raw string: `\m` is not a valid escape sequence.
                cb.set_label(r'$2\mathcal{F}$')
            ax.set_xlabel(
                r'Mid-point (days after $t_\mathrm{{start}}$={})'.format(
                    self.minStartTime))
            ax.set_title(
                'Sliding window length = {} days in increments of {} days'
                .format(self.window_size/86400, self.window_delta/86400),
                )
            if timestamps:
                axT = ax.twiny()
                axT.set_xlim(tmids[0]*1e-9, tmids[-1]*1e-9)
                axT.set_xlabel('Mid-point timestamp [GPS $10^{9}$ s]')
                # Raise the title so it clears the extra top axis.
                ax.set_title(ax.get_title(), y=1.18)
            if savefig:
                plt.tight_layout()
                plt.savefig(
                    '{}/{}_sliding_window.png'.format(self.outdir, self.label))
            else:
                return ax
    
    
    class EarthTest(GridSearch):
        """ Grid search over modifications of the Earth-rotation model

        The grid axes are `deltaRadius`, `phaseOffset` and `deltaPspin`; for
        each point the model is modified via
        `lalpulsar.BarycenterModifyEarthRotation` before the detection
        statistic is computed at fixed signal parameters.
        """
        # Raw strings: `\D` and `\m` are not valid escape sequences.
        tex_labels = {'deltaRadius': r'$\Delta R$ [m]',
                      'phaseOffset': 'phase-offset [rad]',
                      'deltaPspin': r'$\Delta P_\mathrm{spin}$ [s]'}

        @helper_functions.initializer
        def __init__(self, label, outdir, sftfilepattern, deltaRadius,
                     phaseOffset, deltaPspin, F0, F1, F2, Alpha,
                     Delta, tref=None, minStartTime=None, maxStartTime=None,
                     BSGL=False, minCoverFreq=None, maxCoverFreq=None,
                     detectors=None, injectSources=None,
                     assumeSqrtSX=None):
            """
            Parameters
            ----------
            label, outdir: str
                A label and directory to read/write data from/to; `outdir` is
                created (including missing parent directories) if absent.
            sftfilepattern: str
                Pattern to match SFTs using wildcards (*?) and ranges [0-9];
                multiple patterns can be given separated by colons.
            deltaRadius, phaseOffset, deltaPspin: float or array
                Grid values for each Earth-model offset; scalars are
                promoted to one-element arrays.
            F0, F1, F2, Alpha, Delta: float
                Fixed signal parameters at which the statistic is computed
            tref, minStartTime, maxStartTime: int
                GPS seconds of the reference time, start time and end time.
                Despite the `None` defaults, minStartTime and maxStartTime
                must be numbers: their difference is stored as the duration.

            For all other parameters, see `pyfstat.ComputeFStat` for details
            """
            self.transientWindowType = None
            self.t0Band = None
            self.tauBand = None

            # makedirs (rather than mkdir) so that a nested `outdir` works
            # even when its parent directories do not exist yet.
            if not os.path.isdir(outdir):
                os.makedirs(outdir)
            self.nsegs = 1
            self.F0s = [F0]
            self.F1s = [F1]
            self.F2s = [F2]
            self.Alphas = [Alpha]
            self.Deltas = [Delta]
            self.duration = maxStartTime - minStartTime
            self.deltaRadius = np.atleast_1d(deltaRadius)
            self.phaseOffset = np.atleast_1d(phaseOffset)
            self.phaseOffset = self.phaseOffset + 1e-12  # Hack to stop cached data being used
            self.deltaPspin = np.atleast_1d(deltaPspin)
            self.set_out_file()
            self.SSBprec = lalpulsar.SSBPREC_RELATIVISTIC
            self.keys = ['deltaRadius', 'phaseOffset', 'deltaPspin']

            # Default prior widths: the full extent of each grid axis.
            self.prior_widths = [
                np.max(self.deltaRadius)-np.min(self.deltaRadius),
                np.max(self.phaseOffset)-np.min(self.phaseOffset),
                np.max(self.deltaPspin)-np.min(self.deltaPspin)]

            if not hasattr(self, 'search'):
                self.inititate_search_object()

        def get_input_data_array(self):
            """ Build `self.input_data`, one row per grid point

            Rows are (deltaRadius, phaseOffset, deltaPspin), the Cartesian
            product of the three axes kept in `self.coord_arrays`.
            """
            logging.info("Generating input data array")
            coord_arrays = [self.deltaRadius, self.phaseOffset, self.deltaPspin]
            self.input_data = np.array(list(itertools.product(*coord_arrays)))
            self.coord_arrays = coord_arrays

        def _det_stat_for_earth_model(self, dR, dphi, dP):
            """ Apply one Earth-model modification and compute the statistic

            `dR` and `dP` are turned into multiplicative rescalings relative
            to `lal.REARTH_SI` and `lal.DAYSID_SI`; `dphi` is passed through
            unchanged. Returns the value of `self.search.get_det_stat` at
            the fixed signal parameters.
            """
            rescaleRadius = 1 + dR / lal.REARTH_SI
            rescalePeriod = 1 + dP / lal.DAYSID_SI
            lalpulsar.BarycenterModifyEarthRotation(
                rescaleRadius, dphi, rescalePeriod, self.tref)
            return self.search.get_det_stat(
                self.minStartTime, self.maxStartTime, self.F0, self.F1,
                self.F2, self.Alpha, self.Delta)

        def run_special(self):
            """ Compute the statistic at the special points

            Fills `self.special_data`, mapping a name to
            [dR, dphi, dP, statistic]; only the unmodified model ('zero')
            is evaluated.
            """
            self.special_data = {'zero': [0, 0, 0]}
            # items() works on Python 2 and 3 (iteritems() is Python 2
            # only); iterate over a copy because entries are replaced below.
            for key, (dR, dphi, dP) in list(self.special_data.items()):
                FS = self._det_stat_for_earth_model(dR, dphi, dP)
                self.special_data[key] = [dR, dphi, dP, FS]

        def run(self):
            """ Evaluate the statistic over the grid and store `self.data`

            Results are written to `self.out_file`; if
            `check_old_data_is_okay_to_use` returns previously saved data,
            that is used instead of recomputing.
            """
            self.run_special()
            self.get_input_data_array()
            old_data = self.check_old_data_is_okay_to_use()
            if old_data is not False:
                self.data = old_data
                return

            data = []
            for (dR, dphi, dP) in tqdm(self.input_data):
                FS = self._det_stat_for_earth_model(dR, dphi, dP)
                data.append([dR, dphi, dP, FS])

            # Builtin float: the `np.float` alias was removed from NumPy.
            data = np.array(data, dtype=float)
            logging.info('Saving data to {}'.format(self.out_file))
            np.savetxt(self.out_file, data, delimiter=' ')
            self.data = data

        def marginalised_bayes_factor(self, prior_widths=None):
            """ Compare the statistic at zero offset with the grid

            Parameters
            ----------
            prior_widths: list, optional
                Prior width for each grid axis; defaults to
                `self.prior_widths`.

            Returns
            -------
            tuple
                (F_at_zero - marginalised_F,
                 (F_at_zero - max_F) / F_at_zero), where F is half of the
                last column of `self.data` and `marginalised_F` is its
                log-sum-exp over the grid, weighted by grid spacing over
                prior width along every axis with more than one point.
            """
            if prior_widths is None:
                prior_widths = self.prior_widths

            ndims = self.data.shape[1] - 1
            # Keep a plain list: the axes may have different lengths, and
            # NumPy refuses to build an array from ragged sequences.
            params = [np.unique(self.data[:, j]) for j in range(ndims)]
            twoF = self.data[:, -1].reshape(tuple(len(p) for p in params))
            F = twoF / 2.0
            # Integrate out the last axis first so that axis=-1 always
            # refers to the axis currently being removed.
            for i, x in enumerate(params[::-1]):
                if len(x) > 1:
                    dx = x[1] - x[0]
                    F = logsumexp(F, axis=-1)+np.log(dx)-np.log(prior_widths[-1-i])
                else:
                    F = np.squeeze(F, axis=-1)
            marginalised_F = np.atleast_1d(F)[0]
            F_at_zero = self.special_data['zero'][-1]/2.0

            max_idx = np.argmax(self.data[:, -1])
            max_F = self.data[max_idx, -1]/2.0
            max_F_params = self.data[max_idx, :-1]
            logging.info('F at zero = {:.1f}, marginalised_F = {:.1f},'
                         ' max_F = {:.1f} ({})'.format(
                             F_at_zero, marginalised_F, max_F, max_F_params))
            return F_at_zero - marginalised_F, (F_at_zero - max_F) / F_at_zero

        def plot_corner(self, prior_widths=None, fig=None, axes=None,
                        projection='log_mean'):
            """ Save a corner plot of the grid using `gridcorner`

            Parameters
            ----------
            prior_widths: list, optional
                Passed to `marginalised_bayes_factor`.
            fig, axes:
                Accepted but not used: both are replaced by the objects
                returned from `gridcorner`.
            projection: str
                Passed to `gridcorner`.

            Raises
            ------
            ImportError
                If the `gridcorner` module is not installed.
            """
            Bsa, FmaxMismatch = self.marginalised_bayes_factor(prior_widths)

            data = self.data[:, -1].reshape(
                (len(self.deltaRadius), len(self.phaseOffset),
                 len(self.deltaPspin)))
            xyz = [self.deltaRadius/lal.REARTH_SI, self.phaseOffset/(np.pi),
                   self.deltaPspin/60.]
            labels = [r'$\frac{\Delta R}{R_\mathrm{Earth}}$',
                      r'$\frac{\Delta \phi}{\pi}$',
                      r'$\Delta P_\mathrm{spin}$ [min]',
                      r'$2\mathcal{F}$']

            try:
                from gridcorner import gridcorner
            except ImportError:
                raise ImportError(
                    "Python module 'gridcorner' not found, please install from "
                    "https://gitlab.aei.uni-hannover.de/GregAshton/gridcorner")

            fig, axes = gridcorner(data, xyz, projection=projection, factor=1.6,
                                   labels=labels)
            # Vertical line at the Julian-minus-sidereal day difference, in
            # minutes, on the last panel.
            axes[-1][-1].axvline((lal.DAYJUL_SI - lal.DAYSID_SI)/60.0, color='C3')
            # Raw strings throughout: `\l` and `\m` are not valid escapes.
            plt.suptitle(
                r'T={:.1f} days, $f$={:.2f} Hz, $\log\mathcal{{B}}_{{S/A}}$={:.1f},'
                r' $\frac{{\mathcal{{F}}_0-\mathcal{{F}}_\mathrm{{max}}}}'
                r'{{\mathcal{{F}}_0}}={:.1e}$'
                .format(self.duration/86400, self.F0, Bsa, FmaxMismatch), y=0.99,
                size=14)
            fig.savefig('{}/{}_projection_matrix.png'.format(
                self.outdir, self.label))

        def plot(self, key, prior_widths=None):
            """ Save a 1D plot of the statistic against one grid axis

            Parameters
            ----------
            key: str
                One of 'deltaRadius', 'phaseOffset' or 'deltaPspin'.
            prior_widths: list, optional
                Passed to `marginalised_bayes_factor`.
            """
            Bsa, FmaxMismatch = self.marginalised_bayes_factor(prior_widths)

            rescales_defaults = {'deltaRadius': 1/lal.REARTH_SI,
                                 'phaseOffset': 1/np.pi,
                                 'deltaPspin': 1}
            labels = {'deltaRadius': r'$\frac{\Delta R}{R_\mathrm{Earth}}$',
                      'phaseOffset': r'$\frac{\Delta \phi}{\pi}$',
                      'deltaPspin': r'$\Delta P_\mathrm{spin}$ [s]'
                      }

            fig, ax = self.plot_1D(key, xrescale=rescales_defaults[key],
                                   xlabel=labels[key], savefig=False)
            ax.set_title(
                r'T={} days, $f$={} Hz, $\log\mathcal{{B}}_{{S/A}}$={:.1f}'
                .format(self.duration/86400, self.F0, Bsa))
            fig.tight_layout()
            fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label))
    
    
    class DMoff_NO_SPIN(GridSearch):
        """ DMoff test using SSBPREC_NO_SPIN """
        @helper_functions.initializer
        def __init__(self, par, label, outdir, sftfilepattern, minStartTime=None,
                     maxStartTime=None, minCoverFreq=None, maxCoverFreq=None,
                     detectors=None, injectSources=None, assumeSqrtSX=None):
            """
            Parameters
            ----------
            par: dict, str
                Either a par dictionary (containing 'F0', 'F1', 'Alpha', 'Delta'
                and 'tref') or a path to a .par file to read in the F0, F1 etc
            label, outdir: str
                A label and directory to read/write data from/to; `outdir` is
                created (including missing parent directories) if absent.
            sftfilepattern: str
                Pattern to match SFTs using wildcards (*?) and ranges [0-9];
                multiple patterns can be given separated by colons.
            minStartTime, maxStartTime: int
                GPS seconds of the start time and end time

            For all other parameters, see `pyfstat.ComputeFStat` for details

            Raises
            ------
            ValueError
                If `par` is neither a dict nor the path of an existing file.
            """

            # makedirs (rather than mkdir) so that a nested `outdir` works
            # even when its parent directories do not exist yet.
            if not os.path.isdir(outdir):
                os.makedirs(outdir)

            # isinstance so that dict subclasses (e.g. OrderedDict) are
            # accepted as par dictionaries too.
            if isinstance(par, dict):
                self.par = par
            elif isinstance(par, str) and os.path.isfile(par):
                self.par = read_par(filename=par)
            else:
                raise ValueError('The .par file does not exist')

            self.nsegs = 1
            self.BSGL = False

            # 'F1' and 'F2' default to 0 when missing; the other entries
            # are required.
            self.tref = self.par['tref']
            self.F1s = [self.par.get('F1', 0)]
            self.F2s = [self.par.get('F2', 0)]
            self.Alphas = [self.par['Alpha']]
            self.Deltas = [self.par['Delta']]
            self.Re = 6.371e6
            self.c = 2.998e8
            a0 = self.Re/self.c  # *np.cos(self.par['Delta'])
            # Number of frequency offsets used on each side of F0 in
            # get_results: at least 4, otherwise ceil(2*pi*F0*Re/c).
            self.m0 = np.max([4, int(np.ceil(2*np.pi*self.par['F0']*a0))])
            logging.info(
                'Setting up DMoff_NO_SPIN search with m0 = {}'.format(self.m0))

        def _summed_twoF(self, extra_label, day, m):
            """ Run the search at F0 + j/day for j in [-m, m] and sum it

            Sets the output file from `extra_label`, replaces `self.F0s`
            with the 2*m + 1 offset frequencies, calls `run()` and returns
            the sum of the last column of `self.data`.
            """
            self.set_out_file(extra_label)
            self.F0s = [self.par['F0']+j/day for j in range(-m, m+1)]
            self.run()
            return np.sum(self.data[:, -1])

        def get_results(self):
            """ Compute the three summed detection statistics

            Returns
            -------
                m0, twoF_SUM, twoFstar_SUM_SIDEREAL, twoFstar_SUM_TERRESTRIAL

            """
            # Relativistic precision: four offsets on each side of F0,
            # spaced by the inverse of lal.DAYSID_SI.
            self.SSBprec = lalpulsar.SSBPREC_RELATIVISTIC
            twoF_SUM = self._summed_twoF(
                'SSBPREC_RELATIVISTIC', lal.DAYSID_SI, 4)

            # NO_SPIN precision: m0 offsets on each side, first with the
            # lal.DAYSID_SI spacing and then with the lal.DAYJUL_SI one.
            self.SSBprec = lalpulsar.SSBPREC_NO_SPIN
            twoFstar_SUM = self._summed_twoF(
                'SSBPREC_NO_SPIN', lal.DAYSID_SI, self.m0)
            twoFstar_SUM_terrestrial = self._summed_twoF(
                'SSBPREC_NO_SPIN_TERRESTRIAL', lal.DAYJUL_SI, self.m0)

            return self.m0, twoF_SUM, twoFstar_SUM, twoFstar_SUM_terrestrial