Skip to content
Snippets Groups Projects
Select Git revision
  • f4c88a649d0903c7a68c66af10aa7bf860ac5ca0
  • master default protected
2 results

RDGW150914_emcee.ipynb

Blame
  • Forked from Xisco Jimenez Forteza / RDStackingProject
    Source project has a limited visibility.
    pyfstat.py 50.58 KiB
    """ Classes for various types of searches using ComputeFstatistic """
    import os
    import sys
    import itertools
    import logging
    import argparse
    import copy
    import glob
    import inspect
    from functools import wraps
    
    import numpy as np
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    import emcee
    import corner
    import dill as pickle
    import lalpulsar
    
    try:
        from ephemParams import earth_ephem, sun_ephem
    except (IOError, ImportError):
        logging.warning('No ephemParams.py file found, or it does not contain '
                        'earth_ephem and sun_ephem, please provide the paths when '
                        'initialising searches')
        earth_ephem = None
        sun_ephem = None
    
    plt.style.use('paper')
    
    
    def initializer(func):
        """ Automatically assigns the parameters to self"""
        names, varargs, keywords, defaults = inspect.getargspec(func)
    
        @wraps(func)
        def wrapper(self, *args, **kargs):
            for name, arg in list(zip(names[1:], args)) + list(kargs.items()):
                setattr(self, name, arg)
    
            for name, default in zip(reversed(names), reversed(defaults)):
                if not hasattr(self, name):
                    setattr(self, name, default)
    
            func(self, *args, **kargs)
    
        return wrapper
    
    
    def read_par(label, outdir):
        filename = '{}/{}.par'.format(outdir, label)
        d = {}
        with open(filename, 'r') as f:
            for line in f:
                key, val = line.rstrip('\n').split(' = ')
                d[key] = np.float64(val)
        return d
    
    parser = argparse.ArgumentParser()
    parser.add_argument("-q", "--quite", help="Decrease output verbosity",
                        action="store_true")
    parser.add_argument("-c", "--clean", help="Don't use cached data",
                        action="store_true")
    parser.add_argument('unittest_args', nargs='*')
    args, unknown = parser.parse_known_args()
    sys.argv[1:] = args.unittest_args
    
    if args.quite:
        log_level = logging.WARNING
    else:
        log_level = logging.DEBUG
    
    logging.basicConfig(level=log_level,
                        format='%(asctime)s %(levelname)-8s: %(message)s',
                        datefmt='%H:%M')
    
    
    class BaseSearchClass(object):
    
        earth_ephem_default = earth_ephem
        sun_ephem_default = sun_ephem
    
        def shift_matrix(self, n, dT):
            """ Generate the shift matrix """
            m = np.zeros((n, n))
            factorial = np.math.factorial
            for i in range(n):
                for j in range(n):
                    if i == j:
                        m[i, j] = 1.0
                    elif i > j:
                        m[i, j] = 0.0
                    else:
                        if i == 0:
                            m[i, j] = 2*np.pi*float(dT)**(j-i) / factorial(j-i)
                        else:
                            m[i, j] = float(dT)**(j-i) / factorial(j-i)
    
            return m
    
        def shift_coefficients(self, theta, dT):
            """ Shift a set of coefficients by dT
    
            Parameters
            ----------
            theta: array-like, shape (n,)
                vector of the expansion coefficients to transform starting from the
                lowest degree e.g [phi, F0, F1,...]
            dT: float
                difference between the two reference times as tref_new - tref_old
    
            Returns
            -------
            theta_new: array-like shape (n,)
                vector of the coefficients as evaluate as the new reference time
            """
            n = len(theta)
            m = self.shift_matrix(n, dT)
            return np.dot(m, theta)
    
        # Rewrite this to generalise to N glitches, then use everywhere!
        def calculate_thetas(self, theta, delta_thetas, tbounds):
            """ Calculates the set of coefficients for the post-glitch signal """
            thetas = [theta]
            for i, dt in enumerate(delta_thetas):
                pre_theta_at_ith_glitch = self.shift_coefficients(
                    thetas[i], tbounds[i+1] - self.tref)
                post_theta_at_ith_glitch = pre_theta_at_ith_glitch + dt
                thetas.append(self.shift_coefficients(
                    post_theta_at_ith_glitch, self.tref - tbounds[i+1]))
            return thetas
    
    
    class FullyCoherentNarrowBandSearch(BaseSearchClass):
        """ Search over a narrow band of F0, F1, and F2 """
    
        @initializer
        def __init__(self, label, outdir, sftlabel=None, sftdir=None,
                     tglitch=None, tref=None, tstart=None, Alpha=None, Delta=None,
                     duration=None, Writer=None):
            """
            Parameters
            ----------
            label, outdir: str
                A label and directory to read/write data from/to
            sftlabel, sftdir: str
                A label and directory in which to find the relevant sft file. If
                None use label and outdir
            """
            if self.sftlabel is None:
                self.sftlabel = self.label
            if self.sftdir is None:
                self.sftdir = self.outdir
            self.tend = self.tstart + self.duration
            self.calculate_best_fit_F0_and_F1(Writer)
            self.fs_file_name = "{}/{}_FS.dat".format(self.outdir, self.label)
    
        def calculate_best_fit_F0_and_F1(self, writer):
            R = (writer.tglitch - writer.tstart) / float(writer.duration)
            self.F0_min = (writer.F0 + writer.delta_F0*(1-R)**2*(2*R+1) +
                           3*writer.delta_phi*R*(1-R)/np.pi/writer.duration +
                           # .5*writer.duration*writer.F1*(1-R) +
                           .5*writer.duration*writer.delta_F1*(1-R)**3*(1+R))
    
            self.F1_min = (writer.F1 + writer.delta_F1*(1-R)**3*(6*R**2+3*R+1) +
                           30*writer.delta_F0*R**2*(1-R)**2/writer.duration +
                           30*writer.delta_phi*R*(1-R)*(2*R-1)/np.pi/writer.duration**2
                           )
    
        def get_grid_setup(self, m, n, searchF2=False):
            """ Calc. the grid parameters of bands given the metric-mismatch
    
            Parameters
            ----------
            m: float in [0, 1]
                The mismatch spacing between adjacent grid points
            n: int
                Number of grid points to search
    
            """
            DeltaF0 = np.sqrt(12 * m) / (np.pi * self.duration)
            DeltaF1 = np.sqrt(720 * m) / (np.pi * self.duration**2.0)
            DeltaF2 = np.sqrt(100800 * m) / (np.pi * self.duration**3.0)
    
            # Calculate the width of bands given n
            F0Band = n * DeltaF0
            F1Band = n * DeltaF1
            F2Band = n * DeltaF2
    
            # Search takes the lowest frequency in the band
            F0_bottom = self.F0_min - .5 * F0Band
            F1_bottom = self.F1_min - .5 * F1Band
            if searchF2:
                F2_bottom = self.F2_min-.5*self.F2Band  # Not yet implemented
            else:
                F2_bottom = 0  # Broken functionality
                F2Band = 0
    
            Messg = ["Automated search for {}:".format(self.label),
                     "Grid parameters : m={}, n={}".format(m, n),
                     "Reference time: {}".format(self.tref),
                     "Analytic best-fit values : {}, {}".format(
                         self.F0_min, self.F1_min),
                     "F0Band : {} -- {}".format(F0_bottom,
                                                F0_bottom + F0Band),
                     "F1Band : {} -- {}".format(F1_bottom,
                                                F1_bottom + F1Band),
                     "F2Band : {} -- {}".format(F2_bottom,
                                                F2_bottom + F2Band),
                     ]
            logging.info("\n  ".join(Messg))
    
            return (F0_bottom, DeltaF0, F0Band,
                    F1_bottom, DeltaF1, F1Band,
                    F2_bottom, DeltaF2, F2Band)
    
        def run_computefstatistic_slow(self, m, n, search_F2=False):
            """ Compute the f statistic fully-coherently over a grid """
    
            (F0_bottom, DeltaF0, F0Band,
             F1_bottom, DeltaF1, F1Band,
             F2_bottom, DeltaF2, F2Band) = self.get_grid_setup(m, n)
    
            c_l = []
            c_l.append("lalapps_ComputeFstatistic_v2")
            c_l.append("--Freq={}".format(F0_bottom))
            c_l.append("--dFreq={}".format(DeltaF0))
            c_l.append("--FreqBand={}".format(F0Band))
    
            c_l.append("--f1dot={}".format(F1_bottom))
            c_l.append("--df1dot={}".format(DeltaF1))
            c_l.append("--f1dotBand={}".format(F1Band))
    
            if search_F2:
                c_l.append("--f2dot={}".format(F2_bottom))
                c_l.append("--df2dot={}".format(DeltaF2))
                c_l.append("--f2dotBand={}".format(F2Band))
            else:
                c_l.append("--f2dot={}".format(F2_bottom))
    
            c_l.append("--DataFiles='{}'".format(
                self.outdir+"/*SFT_"+self.label+"*sft"))
    
            c_l.append("--refTime={:10.6f}".format(self.tref))
            c_l.append("--outputFstat='{}'".format(self.fs_file_name))
    
            c_l.append("--Alpha={}".format(self.Alpha))
            c_l.append("--Delta={}".format(self.Delta))
    
            c_l.append("--minStartTime={}".format(self.tstart))
            c_l.append("--maxStartTime={}".format(self.tend))
    
            logging.info("Executing: " + " ".join(c_l) + "\n")
            os.system(" ".join(c_l))
    
            self.read_in_fstat()
    
        def run_computefstatistic(self, dFreq=0, numFreqBins=1):
            """ Compute the f statistic fully-coherently over a grid """
    
            constraints = lalpulsar.SFTConstraints()
            FstatOptionalArgs = lalpulsar.FstatOptionalArgsDefaults
            minCoverFreq = 29
            maxCoverFreq = 31
    
            SFTCatalog = lalpulsar.SFTdataFind(
                    self.sftdir+'/*_'+self.sftlabel+"*sft", constraints)
    
            ephemerides = lalpulsar.InitBarycenter(
                '~/lalsuite-install/share/lalpulsar/earth00-19-DE421.dat.gz',
                '~/lalsuite-install/share/lalpulsar/sun00-19-DE421.dat.gz')
    
            whatToCompute = lalpulsar.FSTATQ_2F
            FstatInput = lalpulsar.CreateFstatInput(SFTCatalog,
                                                    minCoverFreq,
                                                    maxCoverFreq,
                                                    dFreq,
                                                    ephemerides,
                                                    FstatOptionalArgs
                                                    )
    
            PulsarDopplerParams = lalpulsar.PulsarDopplerParams()
            PulsarDopplerParams.refTime = self.tref
            PulsarDopplerParams.Alpha = self.Alpha
            PulsarDopplerParams.Delta = self.Delta
            PulsarDopplerParams.fkdot = np.array([self.F0_min-dFreq*numFreqBins/2.,
                                                  self.F1_min, 0, 0, 0, 0, 0])
    
            FstatResults = lalpulsar.FstatResults()
            lalpulsar.ComputeFstat(FstatResults,
                                   FstatInput,
                                   PulsarDopplerParams,
                                   numFreqBins,
                                   whatToCompute
                                   )
            self.search_F0 = (np.linspace(0, dFreq * numFreqBins, numFreqBins)
                              + FstatResults.doppler.fkdot[0])
            self.search_F1 = np.zeros(numFreqBins) + self.F1_min
            self.search_F2 = np.zeros(numFreqBins) + 0
            self.search_FS = FstatResults.twoF
    
        def read_in_fstat(self):
            """
            Read in data from *_FS.dat file as produced by ComputeFStatistic_v2
            """
    
            data = np.genfromtxt(self.fs_file_name, comments="%")
    
            # If none of the components are varying:
            if data.ndim == 1:
                self.search_F0 = data[0]
                self.search_F1 = data[3]
                self.search_F2 = data[4]
                self.search_FS = data[6]
                return
    
            search_F0 = data[:, 0]
            search_F1 = data[:, 3]
            search_F2 = data[:, 4]
            search_FS = data[:, 6]
    
            NF0 = len(np.unique(search_F0))
            NF1 = len(np.unique(search_F1))
            NF2 = len(np.unique(search_F2))
    
            shape = (NF2, NF1, NF0)
            self.data_shape = shape
            self.search_F0 = np.squeeze(np.reshape(search_F0,
                                        newshape=shape).transpose())
            self.search_F1 = np.squeeze(np.reshape(search_F1,
                                        newshape=shape).transpose())
            self.search_F2 = np.squeeze(np.reshape(search_F2,
                                        newshape=shape).transpose())
            self.search_FS = np.squeeze(np.reshape(search_FS,
                                        newshape=shape).transpose())
    
        def get_FS_max(self):
            """ Returns the maximum FS and the corresponding F0, F1, and F2 """
    
            if np.shape(self.search_FS) == ():
                return self.search_F0, self.search_F1, self.search_F2, self.search_FS
            else:
                max_idx = np.unravel_index(self.search_FS.argmax(), self.search_FS.shape)
                return (self.search_F0[max_idx], self.search_F1[max_idx],
                        self.search_F2[max_idx], self.search_FS[max_idx])
    
        def plot_output(self, output_type='FS', perfectlymatched_FS=None,
                        fig=None, ax=None, savefig=False, title=None):
            """
            Plot the output of the *_FS.dat file as a contour plot
    
            Parameters
            ----------
            output_type: str
                one of 'FS', 'rho', or 'mismatch'
            perfectlymatched_FS: float
                the 2F of a perfectly matched signal against which to
                compute the mismatch
    
            """
    
            resF0 = self.search_F0 - self.F0_min
            resF1 = self.search_F1 - self.F1_min
    
            if output_type == 'FS':
                Z = self.search_FS
                zlabel = r'$2\bar{\mathcal{F}}$'
            elif output_type == 'rho':
                Z = self.search_FS - 4
                zlabel = r'\rho^{2}'
            elif output_type == 'mismatch':
                rho2 = self.search_FS - 4
                perfectlymatched_rho2 = perfectlymatched_FS - 4
                if perfectlymatched_FS:
                    Z = 1 - (rho2) / (perfectlymatched_rho2)
                else:
                    raise ValueError('Plotting the mismatch requires a value for'
                                     ' the parameter perfectlymatched_rho2')
                zlabel = 'mismatch'
    
            if ax is None:
                fig, ax = plt.subplots()
            plt.rcParams['font.size'] = 12
    
            pax = ax.pcolormesh(resF0, resF1, Z, cmap=plt.cm.viridis,
                                vmin=0, vmax=1)
            fig.colorbar(pax, label=zlabel, ax=ax)
            ax.set_xlim(np.min(resF0), np.max(resF0))
            ax.set_ylim(np.min(resF1), np.max(resF1))
    
            ax.set_xlabel(r'$f_0 - f_\textrm{min}$')
            ax.set_ylabel(r'$\dot{f}_0 - \dot{f}_\textrm{min}$')
            ax.set_title(self.label)
    
            plt.tight_layout()
            if savefig:
                fig.savefig('output_{}.png'.format(self.label))
    
    
    class SemiCoherentGlitchSearch(BaseSearchClass):
        """ A semi-coherent glitch search
    
        This implements a basic `semi-coherent glitch F-stat in which the data
        is divided into two segments either side of the proposed glitch and the
        fully-coherent F-stat in each segment is averaged to give the semi-coherent
        F-stat
        """
        # Copy methods
        read_in_fstat = FullyCoherentNarrowBandSearch.__dict__['read_in_fstat']
    
        @initializer
        def __init__(self, label, outdir, tref, tstart, tend, sftlabel=None,
                     nglitch=0, sftdir=None, minCoverFreq=29, maxCoverFreq=31,
                     detector=None, earth_ephem=None, sun_ephem=None):
            """
            Parameters
            ----------
            label, outdir: str
                A label and directory to read/write data from/to
            sftlabel, sftdir: str
                A label and directory in which to find the relevant sft file. If
                None use label and outdir
            tref: int
                GPS seconds of the reference time
            minCoverFreq, maxCoverFreq: float
                The min and max cover frequency passed to CreateFstatInput
            detector: str
                Two character reference to the data to use, specify None for no
                contraint
            earth_ephem, sun_ephem: str
                Paths of the two files containing positions of Earth and Sun,
                respectively at evenly spaced times, as passed to CreateFstatInput
                If None defaults defined in BaseSearchClass will be used
            """
    
            if self.sftlabel is None:
                self.sftlabel = self.label
            if self.sftdir is None:
                self.sftdir = self.outdir
            self.fs_file_name = "{}/{}_FS.dat".format(self.outdir, self.label)
            if self.earth_ephem is None:
                self.earth_ephem = self.earth_ephem_default
            if self.sun_ephem is None:
                self.sun_ephem = self.sun_ephem_default
            self.init_computefstatistic_single_point()
    
        def compute_nglitch_fstat(self, F0, F1, F2, Alpha, Delta, *args):
            """ Compute the semi-coherent glitch F-stat """
    
            args = list(args)
            tboundaries = [self.tstart] + args[-self.nglitch:] + [self.tend]
            delta_F0s = [0] + args[-3*self.nglitch:-2*self.nglitch]
            delta_F1s = [0] + args[-2*self.nglitch:-self.nglitch]
            theta = [F0, F1, F2]
            tref = self.tref
    
            twoFSum = 0
            for i in range(self.nglitch+1):
                ts, te = tboundaries[i], tboundaries[i+1]
    
                if i == 0:
                    theta_at_tref = theta
                else:
                    # Issue here - are these correct?
                    delta_theta = np.array([delta_F0s[i], delta_F1s[i], 0])
                    theta_at_glitch = self.shift_coefficients(theta_at_tref,
                                                              te - tref)
                    theta_post_glitch_at_glitch = theta_at_glitch + delta_theta
                    theta_at_tref = self.shift_coefficients(
                        theta_post_glitch_at_glitch, tref - te)
    
                twoFVal = self.run_computefstatistic_single_point(
                    tref, ts, te, theta_at_tref[0], theta_at_tref[1],
                    theta_at_tref[2], Alpha, Delta)
                twoFSum += twoFVal
    
            return twoFSum
    
        def compute_glitch_fstat_single(self, F0, F1, F2, Alpha, Delta, delta_F0,
                                        delta_F1, tglitch):
            """ Compute the semi-coherent glitch F-stat """
    
            theta = [F0, F1, F2]
            delta_theta = [delta_F0, delta_F1, 0]
            tref = self.tref
    
            theta_at_glitch = self.shift_coefficients(theta, tglitch - tref)
            theta_post_glitch_at_glitch = theta_at_glitch + delta_theta
            theta_post_glitch = self.shift_coefficients(
                theta_post_glitch_at_glitch, tref - tglitch)
    
            twoFsegA = self.run_computefstatistic_single_point(
                tref, self.tstart, tglitch, theta[0], theta[1], theta[2], Alpha,
                Delta)
    
            if tglitch == self.tend:
                return twoFsegA
    
            twoFsegB = self.run_computefstatistic_single_point(
                tref, tglitch, self.tend, theta_post_glitch[0],
                theta_post_glitch[1], theta_post_glitch[2], Alpha,
                Delta)
    
            return twoFsegA + twoFsegB
    
        def init_computefstatistic_single_point(self):
            """ Initilisation step of run_computefstatistic for a single point """
    
            logging.info('Initialising SFTCatalog')
            constraints = lalpulsar.SFTConstraints()
            if self.detector:
                constraints.detector = self.detector
            self.sft_filepath = self.sftdir+'/*_'+self.sftlabel+"*sft"
            SFTCatalog = lalpulsar.SFTdataFind(self.sft_filepath, constraints)
            names = list(set([d.header.name for d in SFTCatalog.data]))
            logging.info('Loaded data from detectors {}'.format(names))
    
            logging.info('Initialising ephems')
            ephems = lalpulsar.InitBarycenter(self.earth_ephem, self.sun_ephem)
    
            logging.info('Initialising FstatInput')
            dFreq = 0
            self.whatToCompute = lalpulsar.FSTATQ_ATOMS_PER_DET
            FstatOptionalArgs = lalpulsar.FstatOptionalArgsDefaults
            self.FstatInput = lalpulsar.CreateFstatInput(SFTCatalog,
                                                         self.minCoverFreq,
                                                         self.maxCoverFreq,
                                                         dFreq,
                                                         ephems,
                                                         FstatOptionalArgs
                                                         )
    
            logging.info('Initialising PulsarDoplerParams')
            PulsarDopplerParams = lalpulsar.PulsarDopplerParams()
            PulsarDopplerParams.refTime = self.tref
            PulsarDopplerParams.Alpha = 1
            PulsarDopplerParams.Delta = 1
            PulsarDopplerParams.fkdot = np.array([0, 0, 0, 0, 0, 0, 0])
            self.PulsarDopplerParams = PulsarDopplerParams
    
            logging.info('Initialising FstatResults')
            self.FstatResults = lalpulsar.FstatResults()
    
        def run_computefstatistic_single_point(self, tref, tstart, tend, F0, F1,
                                               F2, Alpha, Delta):
            """ Compute the F-stat fully-coherently at a single point """
    
            numFreqBins = 1
            self.PulsarDopplerParams.fkdot = np.array([F0, F1, F2, 0, 0, 0, 0])
            self.PulsarDopplerParams.Alpha = Alpha
            self.PulsarDopplerParams.Delta = Delta
    
            lalpulsar.ComputeFstat(self.FstatResults,
                                   self.FstatInput,
                                   self.PulsarDopplerParams,
                                   numFreqBins,
                                   self.whatToCompute
                                   )
    
            windowRange = lalpulsar.transientWindowRange_t()
            windowRange.type = lalpulsar.TRANSIENT_RECTANGULAR
            windowRange.t0 = int(tstart)  # TYPE UINT4
            windowRange.t0Band = 0
            windowRange.dt0 = 1
            windowRange.tau = int(tend - tstart)  # TYPE UINT4
            windowRange.tauBand = 0
            windowRange.dtau = 1
            useFReg = False
            FS = lalpulsar.ComputeTransientFstatMap(self.FstatResults.multiFatoms[0],
                                                    windowRange,
                                                    useFReg)
            return 2*FS.F_mn.data[0][0]
    
        def compute_glitch_fstat_slow(self, F0, F1, F2, Alpha, Delta, delta_F0,
                                      delta_F1, tglitch):
            """ Compute the semi-coherent F-stat """
    
            theta = [F0, F1, F2]
            delta_theta = [delta_F0, delta_F1, 0]
            tref = self.tref
    
            theta_at_glitch = self.shift_coefficients(theta, tglitch - tref)
            theta_post_glitch_at_glitch = theta_at_glitch + delta_theta
            theta_post_glitch = self.shift_coefficients(
                theta_post_glitch_at_glitch, tref - tglitch)
    
            FsegA = self.run_computefstatistic_single_point_slow(
                tref, self.tstart, tglitch, theta[0], theta[1], theta[2], Alpha,
                Delta)
            FsegB = self.run_computefstatistic_single_point_slow(
                tref, tglitch, self.tend, theta_post_glitch[0],
                theta_post_glitch[1], theta_post_glitch[2], Alpha,
                Delta)
    
            return (FsegA + FsegB) / 2.
    
        def run_computefstatistic_single_point_slow(self, tref, tstart, tend, F0,
                                                    F1, F2, Alpha, Delta):
            """ Compute the f statistic fully-coherently at a single point """
    
            c_l = []
            c_l.append("lalapps_ComputeFstatistic_v2")
            c_l.append("--Freq={}".format(F0))
            c_l.append("--f1dot={}".format(F1))
            c_l.append("--f2dot={}".format(F2))
    
            c_l.append("--DataFiles='{}'".format(
                self.outdir+"/*SFT_"+self.label+"*sft"))
    
            c_l.append("--refTime={:10.6f}".format(tref))
            c_l.append("--outputFstat='{}'".format(self.fs_file_name))
    
            c_l.append("--Alpha={}".format(Alpha))
            c_l.append("--Delta={}".format(Delta))
    
            c_l.append("--minStartTime={}".format(tstart))
            c_l.append("--maxStartTime={}".format(tend))
    
            logging.info("Executing: " + " ".join(c_l) + "\n")
            os.system(" ".join(c_l))
    
            self.read_in_fstat()
    
            return self.search_FS
    
    
    class MCMCGlitchSearch(BaseSearchClass):
        """ MCMC search using the SemiCoherentGlitchSearch """
        @initializer
        def __init__(self, label, outdir, sftlabel, sftdir, theta, tref, tstart,
                     tend, nsteps=[100, 100, 100], nwalkers=100, ntemps=1,
                     nglitch=0, minCoverFreq=29, maxCoverFreq=31, scatter_val=1e-4,
                     betas=None, detector=None, dtglitchmin=20*86400,
                     earth_ephem=None, sun_ephem=None):
            """
            Parameters
            label, outdir: str
                A label and directory to read/write data from/to
            sftlabel, sftdir: str
                A label and directory in which to find the relevant sft file
            theta: dict
                Dictionary of priors and fixed values for the search parameters.
                For each parameters (key of the dict), if it is to be held fixed
                the value should be the constant float, if it is be searched, the
                value should be a dictionary of the prior.
            nglitch: int
                The number of glitches to allow
            tref, tstart, tend: int
                GPS seconds of the reference time, start time and end time
            nsteps: list (m,)
                List specifying the number of steps to take, the last two entries
                give the nburn and nprod of the 'production' run, all entries
                before are for iterative initialisation steps (usually just one)
                e.g. [1000, 1000, 500].
            dtglitchmin: int
                The minimum duration (in seconds) of a segment between two glitches
                or a glitch and the start/end of the data
            nwalkers, ntemps: int
                Number of walkers and temperatures
            minCoverFreq, maxCoverFreq: float
                Minimum and maximum instantaneous frequency which will be covered
                over the SFT time span as passed to CreateFstatInput
            earth_ephem, sun_ephem: str
                Paths of the two files containing positions of Earth and Sun,
                respectively at evenly spaced times, as passed to CreateFstatInput
                If None defaults defined in BaseSearchClass will be used
    
            """
    
            logging.info(('Set-up MCMC search with {} glitches for model {} on'
                          ' data {}').format(self.nglitch, self.label,
                                             self.sftlabel))
            if os.path.isdir(outdir) is False:
                os.mkdir(outdir)
            self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
            self.unpack_input_theta()
            self.ndim = len(self.theta_keys)
            self.sft_filepath = self.sftdir+'/*_'+self.sftlabel+"*sft"
            if earth_ephem is None:
                self.earth_ephem = self.earth_ephem_default
            if sun_ephem is None:
                self.sun_ephem = self.sun_ephem_default
    
            if args.clean and os.path.isfile(self.pickle_path):
                os.rename(self.pickle_path, self.pickle_path+".old")
    
            self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
    
        def inititate_search_object(self):
            logging.info('Setting up search object')
            self.search = SemiCoherentGlitchSearch(
                label=self.label, outdir=self.outdir, sftlabel=self.sftlabel,
                sftdir=self.sftdir, tref=self.tref, tstart=self.tstart,
                tend=self.tend, minCoverFreq=self.minCoverFreq,
                maxCoverFreq=self.maxCoverFreq, earth_ephem=self.earth_ephem,
                sun_ephem=self.sun_ephem, detector=self.detector,
                nglitch=self.nglitch)
    
        def logp(self, theta_vals, theta_prior, theta_keys, search):
            if self.nglitch > 1:
                ts = [self.tstart] + theta_vals[-self.nglitch:] + [self.tend]
                if np.array_equal(ts, np.sort(ts)) is False:
                    return -np.inf
                if any(np.diff(ts) < self.dtglitchmin):
                    return -np.inf
    
            H = [self.Generic_lnprior(**theta_prior[key])(p) for p, key in
                 zip(theta_vals, theta_keys)]
            return np.sum(H)
    
        def logl(self, theta, search):
            for j, theta_i in enumerate(self.theta_idxs):
                self.fixed_theta[theta_i] = theta[j]
            FS = search.compute_nglitch_fstat(*self.fixed_theta)
            return FS
    
        def unpack_input_theta(self):
            glitch_keys = ['delta_F0', 'delta_F1', 'tglitch']
            full_glitch_keys = list(np.array(
                [[gk]*self.nglitch for gk in glitch_keys]).flatten())
            full_theta_keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta']+full_glitch_keys
            full_theta_keys_copy = copy.copy(full_theta_keys)
    
            glitch_symbols = ['$\delta f$', '$\delta \dot{f}$', r'$t_{glitch}$']
            full_glitch_symbols = list(np.array(
                [[gs]*self.nglitch for gs in glitch_symbols]).flatten())
            full_theta_symbols = (['$f$', '$\dot{f}$', '$\ddot{f}$', r'$\alpha$',
                                   r'$\delta$'] + full_glitch_symbols)
            self.theta_prior = {}
            self.theta_keys = []
            fixed_theta_dict = {}
            for key, val in self.theta.iteritems():
                if type(val) is dict:
                    self.theta_prior[key] = val
                    fixed_theta_dict[key] = 0
                    if key in glitch_keys:
                        for i in range(self.nglitch):
                            self.theta_keys.append(key)
                    else:
                        self.theta_keys.append(key)
                elif type(val) in [float, int, np.float64]:
                    fixed_theta_dict[key] = val
                else:
                    raise ValueError(
                        'Type {} of {} in theta not recognised'.format(
                            type(val), key))
                if key in glitch_keys:
                    for i in range(self.nglitch):
                        full_theta_keys_copy.pop(full_theta_keys_copy.index(key))
                else:
                    full_theta_keys_copy.pop(full_theta_keys_copy.index(key))
    
            if len(full_theta_keys_copy) > 0:
                raise ValueError(('Input dictionary `theta` is missing the'
                                  'following keys: {}').format(
                                      full_theta_keys_copy))
    
            self.fixed_theta = [fixed_theta_dict[key] for key in full_theta_keys]
            self.theta_idxs = [full_theta_keys.index(k) for k in self.theta_keys]
            self.theta_symbols = [full_theta_symbols[i] for i in self.theta_idxs]
    
            idxs = np.argsort(self.theta_idxs)
            self.theta_idxs = [self.theta_idxs[i] for i in idxs]
            self.theta_symbols = [self.theta_symbols[i] for i in idxs]
            self.theta_keys = [self.theta_keys[i] for i in idxs]
    
            # Correct for number of glitches in the idxs
            self.theta_idxs = np.array(self.theta_idxs)
            while np.sum(self.theta_idxs[:-1] == self.theta_idxs[1:]) > 0:
                for i, idx in enumerate(self.theta_idxs):
                    if idx in self.theta_idxs[:i]:
                        self.theta_idxs[i] += 1
    
        def check_initial_points(self, p0):
            initial_priors = np.array([
                self.logp(p, self.theta_prior, self.theta_keys, self.search)
                for p in p0[0]])
            number_of_initial_out_of_bounds = sum(initial_priors == -np.inf)
            if number_of_initial_out_of_bounds > 0:
                logging.warning(
                    'Of {} initial values, {} are -np.inf due to the prior'.format(
                        len(initial_priors), number_of_initial_out_of_bounds))
    
        def run(self):
    
            if self.old_data_is_okay_to_use is True:
                logging.warning('Using saved data from {}'.format(
                    self.pickle_path))
                d = self.get_saved_data()
                self.sampler = d['sampler']
                self.samples = d['samples']
                self.lnprobs = d['lnprobs']
                self.lnlikes = d['lnlikes']
                return
    
            self.inititate_search_object()
    
            sampler = emcee.PTSampler(
                self.ntemps, self.nwalkers, self.ndim, self.logl, self.logp,
                logpargs=(self.theta_prior, self.theta_keys, self.search),
                loglargs=(self.search,), betas=self.betas)
    
            p0 = self.GenerateInitial()
            self.check_initial_points(p0)
    
            ninit_steps = len(self.nsteps) - 2
            for j, n in enumerate(self.nsteps[:-2]):
                logging.info('Running {}/{} initialisation with {} steps'.format(
                    j, ninit_steps, n))
                sampler.run_mcmc(p0, n)
    
                fig, axes = self.PlotWalkers(sampler, symbols=self.theta_symbols)
                fig.savefig('{}/{}_init_{}_walkers.png'.format(
                    self.outdir, self.label, j))
    
                p0 = self.get_new_p0(sampler, scatter_val=self.scatter_val)
                self.check_initial_points(p0)
                sampler.reset()
    
            nburn = self.nsteps[-2]
            nprod = self.nsteps[-1]
            logging.info('Running final burn and prod with {} steps'.format(
                nburn+nprod))
            sampler.run_mcmc(p0, nburn+nprod)
    
            fig, axes = self.PlotWalkers(sampler, symbols=self.theta_symbols)
            fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label))
    
            samples = sampler.chain[0, :, nburn:, :].reshape((-1, self.ndim))
            lnprobs = sampler.lnprobability[0, :, nburn:].reshape((-1))
            lnlikes = sampler.lnlikelihood[0, :, nburn:].reshape((-1))
            self.sampler = sampler
            self.samples = samples
            self.lnprobs = lnprobs
            self.lnlikes = lnlikes
            self.save_data(sampler, samples, lnprobs, lnlikes)
    
        def plot_corner(self, corner_figsize=(7, 7),  deltat=False,
                        add_prior=False, nstds=None, label_offset=0.4, **kwargs):
    
            fig, axes = plt.subplots(self.ndim, self.ndim,
                                     figsize=corner_figsize)
    
            samples_plt = copy.copy(self.samples)
            theta_symbols_plt = copy.copy(self.theta_symbols)
            theta_symbols_plt = [s.replace('_{glitch}', r'_\textrm{glitch}') for s
                                 in theta_symbols_plt]
    
            if deltat:
                samples_plt[:, self.theta_keys.index('tglitch')] -= self.tref
                theta_symbols_plt[self.theta_keys.index('tglitch')] = (
                    r'$t_{\textrm{glitch}} - t_{\textrm{ref}}$')
    
            if type(nstds) is int and 'range' not in kwargs:
                _range = []
                for j, s in enumerate(samples_plt.T):
                    median = np.median(s)
                    std = np.std(s)
                    _range.append((median - nstds*std, median + nstds*std))
            else:
                _range = None
    
            fig_triangle = corner.corner(samples_plt,
                                         labels=theta_symbols_plt,
                                         fig=fig,
                                         bins=50,
                                         max_n_ticks=4,
                                         plot_contours=True,
                                         plot_datapoints=True,
                                         label_kwargs={'fontsize': 8},
                                         data_kwargs={'alpha': 0.1,
                                                      'ms': 0.5},
                                         range=_range,
                                         **kwargs)
    
            axes_list = fig_triangle.get_axes()
            axes = np.array(axes_list).reshape(self.ndim, self.ndim)
            plt.draw()
            for ax in axes[:, 0]:
                ax.yaxis.set_label_coords(-label_offset, 0.5)
            for ax in axes[-1, :]:
                ax.xaxis.set_label_coords(0.5, -label_offset)
            for ax in axes_list:
                ax.set_rasterized(True)
                ax.set_rasterization_zorder(-10)
            plt.tight_layout(h_pad=0.0, w_pad=0.0)
            fig.subplots_adjust(hspace=0.05, wspace=0.05)
    
            if add_prior:
                self.add_prior_to_corner(axes, samples_plt)
    
            fig_triangle.savefig('{}/{}_corner.png'.format(
                self.outdir, self.label))
    
        def add_prior_to_corner(self, axes, samples):
            for i, key in enumerate(self.theta_keys):
                ax = axes[i][i]
                xlim = ax.get_xlim()
                s = samples[:, i]
                prior = self.Generic_lnprior(**self.theta_prior[key])
                x = np.linspace(s.min(), s.max(), 100)
                ax2 = ax.twinx()
                ax2.get_yaxis().set_visible(False)
                ax2.plot(x, [prior(xi) for xi in x], '-r')
                ax.set_xlim(xlim)
    
        def get_new_p0(self, sampler, scatter_val=1e-3):
            """ Returns new initial positions for walkers are burn0 stage
    
            This returns new positions for all walkers by scattering points about
            the maximum posterior with scale `scatter_val`.
    
            """
            if sampler.chain[:, :, -1, :].shape[0] == 1:
                ntemps_temp = 1
            else:
                ntemps_temp = self.ntemps
            pF = sampler.chain[:, :, -1, :].reshape(
                ntemps_temp, self.nwalkers, self.ndim)[0, :, :]
            lnp = sampler.lnprobability[:, :, -1].reshape(
                self.ntemps, self.nwalkers)[0, :]
            if any(np.isnan(lnp)):
                logging.warning("The sampler has produced nan's")
    
            p = pF[np.nanargmax(lnp)]
            p0 = [[p + scatter_val * p * np.random.randn(self.ndim)
                  for i in xrange(self.nwalkers)] for j in xrange(self.ntemps)]
            if self.nglitch > 1:
                p0 = np.array(p0)
                p0[:, :, -self.nglitch:] = np.sort(p0[:, :, -self.nglitch:],
                                                   axis=2)
            return p0
    
        def Generic_lnprior(self, **kwargs):
            """ Return a lambda function of the pdf
    
            Parameters
            ----------
            kwargs: dict
                A dictionary containing 'type' of pdf and shape parameters
    
            """
    
            def logunif(x, a, b):
                above = x < b
                below = x > a
                if type(above) is not np.ndarray:
                    if above and below:
                        return -np.log(b-a)
                    else:
                        return -np.inf
                else:
                    idxs = np.array([all(tup) for tup in zip(above, below)])
                    p = np.zeros(len(x)) - np.inf
                    p[idxs] = -np.log(b-a)
                    return p
    
            def halfnorm(x, loc, scale):
                if x < 0:
                    return -np.inf
                else:
                    return -0.5*((x-loc)**2/scale**2+np.log(0.5*np.pi*scale**2))
    
            def cauchy(x, x0, gamma):
                return 1.0/(np.pi*gamma*(1+((x-x0)/gamma)**2))
    
            def exp(x, x0, gamma):
                if x > x0:
                    return np.log(gamma) - gamma*(x - x0)
                else:
                    return -np.inf
    
            if kwargs['type'] == 'unif':
                return lambda x: logunif(x, kwargs['lower'], kwargs['upper'])
            elif kwargs['type'] == 'halfnorm':
                return lambda x: halfnorm(x, kwargs['loc'], kwargs['scale'])
            elif kwargs['type'] == 'norm':
                return lambda x: -0.5*((x - kwargs['loc'])**2/kwargs['scale']**2
                                       + np.log(2*np.pi*kwargs['scale']**2))
            else:
                logging.info("kwargs:", kwargs)
                raise ValueError("Print unrecognise distribution")
    
        def GenerateRV(self, **kwargs):
            dist_type = kwargs.pop('type')
            if dist_type == "unif":
                return np.random.uniform(low=kwargs['lower'], high=kwargs['upper'])
            if dist_type == "norm":
                return np.random.normal(loc=kwargs['loc'], scale=kwargs['scale'])
            if dist_type == "halfnorm":
                return np.abs(np.random.normal(loc=kwargs['loc'],
                                               scale=kwargs['scale']))
            if dist_type == "lognorm":
                return np.random.lognormal(
                    mean=kwargs['loc'], sigma=kwargs['scale'])
            else:
                raise ValueError("dist_type {} unknown".format(dist_type))
    
        def PlotWalkers(self, sampler, symbols=None, alpha=0.4, color="k", temp=0,
                        start=None, stop=None, draw_vline=None):
            """ Plot all the chains from a sampler """
    
            shape = sampler.chain.shape
            if len(shape) == 3:
                nwalkers, nsteps, ndim = shape
                chain = sampler.chain[:, :, :]
            if len(shape) == 4:
                ntemps, nwalkers, nsteps, ndim = shape
                if temp < ntemps:
                    logging.info("Plotting temperature {} chains".format(temp))
                else:
                    raise ValueError(("Requested temperature {} outside of"
                                      "available range").format(temp))
                chain = sampler.chain[temp, :, :, :]
    
            with plt.style.context(('classic')):
                fig, axes = plt.subplots(ndim, 1, sharex=True, figsize=(8, 4*ndim))
    
                if ndim > 1:
                    for i in range(ndim):
                        axes[i].plot(chain[:, start:stop, i].T, color="k",
                                     alpha=alpha)
                        if symbols:
                            axes[i].set_ylabel(symbols[i])
                        if draw_vline is not None:
                            axes[i].axvline(draw_vline, lw=2, ls="--")
    
            return fig, axes
    
        def GenerateInitial(self):
            """ Generate a set of init vals for the walkers based on the prior """
            p0 = [[[self.GenerateRV(**self.theta_prior[key])
                    for key in self.theta_keys]
                   for i in range(self.nwalkers)]
                  for j in range(self.ntemps)]
    
            # Order the times to start the right way around
            if self.nglitch > 1:
                p0 = np.array(p0)
                p0[:, :, -self.nglitch:] = np.sort(p0[:, :, -self.nglitch:],
                                                   axis=2)
            return p0
    
        def get_save_data_dictionary(self):
            d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers,
                     ntemps=self.ntemps, theta_keys=self.theta_keys,
                     theta_prior=self.theta_prior, scatter_val=self.scatter_val)
            return d
    
        def save_data(self, sampler, samples, lnprobs, lnlikes):
            d = self.get_save_data_dictionary()
            d['sampler'] = sampler
            d['samples'] = samples
            d['lnprobs'] = lnprobs
            d['lnlikes'] = lnlikes
    
            if os.path.isfile(self.pickle_path):
                logging.info('Saving backup of {} as {}.old'.format(
                    self.pickle_path, self.pickle_path))
                os.rename(self.pickle_path, self.pickle_path+".old")
            with open(self.pickle_path, "wb") as File:
                pickle.dump(d, File)
    
        def get_list_of_matching_sfts(self):
            matches = glob.glob(self.sft_filepath)
            if len(matches) > 0:
                return matches
            else:
                raise IOError('No sfts found matching {}'.format(
                    self.sft_filepath))
    
        def get_saved_data(self):
            with open(self.pickle_path, "r") as File:
                d = pickle.load(File)
            return d
    
        def check_old_data_is_okay_to_use(self):
            if os.path.isfile(self.pickle_path) is False:
                logging.info('No pickled data found')
                return False
    
            oldest_sft = min([os.path.getmtime(f) for f in
                              self.get_list_of_matching_sfts()])
            if os.path.getmtime(self.pickle_path) < oldest_sft:
                logging.info('Pickled data outdates sft files')
                return False
    
            old_d = self.get_saved_data().copy()
            new_d = self.get_save_data_dictionary().copy()
    
            old_d.pop('samples')
            old_d.pop('sampler')
            old_d.pop('lnprobs')
            old_d.pop('lnlikes')
    
            mod_keys = []
            for key in new_d.keys():
                if key in old_d:
                    if new_d[key] != old_d[key]:
                        mod_keys.append((key, old_d[key], new_d[key]))
                else:
                    raise ValueError('Keys do not match')
    
            if len(mod_keys) == 0:
                return True
            else:
                logging.warning("Saved data differs from requested")
                logging.info("Differences found in following keys:")
                for key in mod_keys:
                    if len(key) == 3:
                        if np.isscalar(key[1]) or key[0] == 'nsteps':
                            logging.info("{} : {} -> {}".format(*key))
                        else:
                            logging.info(key[0])
                    else:
                        logging.info(key)
                return False
    
        def get_max_twoF(self, threshold=0.05):
            """ Returns the max 2F sample and the corresponding 2F value
    
            Note: the sample is returned as a dictionary along with an estimate of
            the standard deviation calculated from the std of all samples with a
            twoF within `threshold` (relative) to the max twoF
    
            """
            if any(np.isposinf(self.lnlikes)):
                logging.info('twoF values contain positive infinite values')
            if any(np.isneginf(self.lnlikes)):
                logging.info('twoF values contain negative infinite values')
            if any(np.isnan(self.lnlikes)):
                logging.info('twoF values contain nan')
            idxs = np.isfinite(self.lnlikes)
            jmax = np.nanargmax(self.lnlikes[idxs])
            maxtwoF = self.lnlikes[jmax]
            d = {}
            close_idxs = abs((maxtwoF - self.lnlikes[idxs]) / maxtwoF) < threshold
            for i, k in enumerate(self.theta_keys):
                base_key = copy.copy(k)
                ng = 1
                while k in d:
                    k = base_key + '_{}'.format(ng)
                d[k] = self.samples[jmax][i]
    
                s = self.samples[:, i][close_idxs]
                d[k + '_std'] = np.std(s)
            return d, maxtwoF
    
        def get_median_stds(self):
            """ Returns a dict of the median and std of all production samples """
            d = {}
            for s, k in zip(self.samples.T, self.theta_keys):
                d[k] = np.median(s)
                d[k+'_std'] = np.std(s)
            return d
    
        def write_par(self, method='med'):
            """ Writes a .par of the best-fit params with an estimated std """
            logging.info('Writing {}/{}.par using the {} method'.format(
                self.outdir, self.label, method))
            if method == 'med':
                median_std_d = self.get_median_stds()
                filename = '{}/{}.par'.format(self.outdir, self.label)
                with open(filename, 'w+') as f:
                    for key, val in median_std_d.iteritems():
                        f.write('{} = {:1.16e}\n'.format(key, val))
            if method == 'twoFmax':
                max_twoF_d, _ = self.get_max_twoF()
                filename = '{}/{}.par'.format(self.outdir, self.label)
                with open(filename, 'w+') as f:
                    for key, val in max_twoF_d.iteritems():
                        f.write('{} = {:1.16e}\n'.format(key, val))
    
        def print_summary(self):
            d, max_twoF = self.get_max_twoF()
            print('Max twoF: {}'.format(max_twoF))
            for k in np.sort(d.keys()):
                if 'std' not in k:
                    print('{:10s} = {:1.9e} +/- {:1.9e}'.format(
                        k, d[k], d[k+'_std']))
    
    
    class GridGlitchSearch(BaseSearchClass):
        """ Gridded search using the SemiCoherentGlitchSearch """
        @initializer
        def __init__(self, label, outdir, sftlabel=None, sftdir=None, F0s=[0],
                     F1s=[0], F2s=[0], delta_F0s=[0], delta_F1s=[0], tglitchs=None,
                     Alphas=[0], Deltas=[0], tref=None, tstart=None, tend=None,
                     minCoverFreq=29, maxCoverFreq=31, write_after=1000,
                     earth_ephem=None, sun_ephem=None):
            """
            Parameters
            label, outdir: str
                A label and directory to read/write data from/to
            sftlabel, sftdir: str
                A label and directory in which to find the relevant sft file
            F0s, F1s, F2s, delta_F0s, delta_F1s, tglitchs, Alphas, Deltas: tuple
                Length 3 tuple describing the grid for each parameter, e.g
                [F0min, F0max, dF0], for a fixed value simply give [F0].
            tref, tstart, tend: int
                GPS seconds of the reference time, start time and end time
            minCoverFreq, maxCoverFreq: float
                Minimum and maximum instantaneous frequency which will be covered
                over the SFT time span as passed to CreateFstatInput
            earth_ephem, sun_ephem: str
                Paths of the two files containing positions of Earth and Sun,
                respectively at evenly spaced times, as passed to CreateFstatInput
                If None defaults defined in BaseSearchClass will be used
    
            """
            if tglitchs is None:
                self.tglitchs = [self.tend]
            if sftlabel is None:
                self.sftlabel = self.label
            if sftdir is None:
                self.sftdir = self.outdir
            if earth_ephem is None:
                self.earth_ephem = self.earth_ephem_default
            if sun_ephem is None:
                self.sun_ephem = self.sun_ephem_default
    
            self.search = SemiCoherentGlitchSearch(
                label=label, outdir=outdir, sftlabel=sftlabel, sftdir=sftdir,
                tref=tref, tstart=tstart, tend=tend, minCoverFreq=minCoverFreq,
                maxCoverFreq=maxCoverFreq, earth_ephem=self.earth_ephem,
                sun_ephem=self.sun_ephem)
    
            if os.path.isdir(outdir) is False:
                os.mkdir(outdir)
            self.out_file = '{}/{}_gridFS.txt'.format(self.outdir, self.label)
            self.keys = ['F0', 'F1', 'F2', 'delta_F0', 'delta_F1', 'tglitch',
                         'Alpha', 'Delta']
    
        def get_array_from_tuple(self, x):
            if len(x) == 1:
                return np.array(x)
            else:
                return np.arange(x[0], x[1]*(1+1e-15), x[2])
    
        def get_input_data_array(self):
            arrays = []
            for tup in (self.F0s, self.F1s, self.F2s, self.Alphas, self.Deltas,
                        self.delta_F0s, self.delta_F1s, self.tglitchs):
                arrays.append(self.get_array_from_tuple(tup))
    
            input_data = []
            for vals in itertools.product(*arrays):
                input_data.append(vals)
    
            self.arrays = arrays
            self.input_data = np.array(input_data)
    
        def check_old_data_is_okay_to_use(self):
            if os.path.isfile(self.out_file) is False:
                logging.info('No old data found, continuing with grid search')
                return False
            data = np.atleast_2d(np.genfromtxt(self.out_file, delimiter=' '))
            if np.all(data[:, 0:-1] == self.input_data):
                logging.info(
                    'Old data found with matching input, no search performed')
                return data
            else:
                logging.info(
                    'Old data found, input differs, continuing with grid search')
                return False
    
        def run(self):
            self.get_input_data_array()
            old_data = self.check_old_data_is_okay_to_use()
            if old_data is not False:
                self.data = old_data
                return
    
            logging.info('Total number of grid points is {}'.format(
                len(self.input_data)))
    
            counter = 0
            data = []
            for vals in self.input_data:
                FS = self.search.compute_glitch_fstat(*vals)
                data.append(list(vals) + [FS])
    
                if counter > self.write_after:
                    np.savetxt(self.out_file, data, delimiter=' ')
                    counter = 0
                    data = []
    
            logging.info('Saving data to {}'.format(self.out_file))
            np.savetxt(self.out_file, data, delimiter=' ')
            self.data = np.array(data)
    
        def plot_2D(self, xkey, ykey):
            fig, ax = plt.subplots()
            xidx = self.keys.index(xkey)
            yidx = self.keys.index(ykey)
            x = np.unique(self.data[:, xidx])
            y = np.unique(self.data[:, yidx])
            z = self.data[:, -1]
    
            X, Y = np.meshgrid(x, y)
            Z = z.reshape(X.shape)
    
            pax = ax.pcolormesh(X, Y, Z, cmap=plt.cm.viridis)
            fig.colorbar(pax)
            ax.set_xlim(x[0], x[-1])
            ax.set_ylim(y[0], y[-1])
            ax.set_xlabel(xkey)
            ax.set_ylabel(ykey)
    
            fig.tight_layout()
            fig.savefig('{}/{}_2D.png'.format(self.outdir, self.label))
    
        def get_max_twoF(self):
            twoF = self.data[:, -1]
            return np.max(twoF)