Skip to content
Snippets Groups Projects
Select Git revision
  • 4e7969841a198ef16f53b25359ed8d5d31242c41
  • master default protected
  • 72-improve-docs-for_optimal_setup
  • os-path-join
  • develop-GA
  • add-higher-spindown-components
  • v1.3
  • v1.2
  • v1.1.2
  • v1.1.0
  • v1.0.1
11 results

transient_search_using_MCMC_make_simulated_data.py

Blame
  • grid_based_searches.py 11.23 KiB
    """ Searches using grid-based methods """
    
    import os
    import logging
    import itertools
    from collections import OrderedDict
    
    import numpy as np
    import matplotlib
    import matplotlib.pyplot as plt
    
    import helper_functions
    from core import BaseSearchClass, ComputeFstat, SemiCoherentGlitchSearch
    from core import tqdm, args, earth_ephem, sun_ephem
    
    
    class GridSearch(BaseSearchClass):
        """ Gridded search using ComputeFstat """
        @helper_functions.initializer
        def __init__(self, label, outdir, sftfilepath, F0s=[0], F1s=[0], F2s=[0],
                     Alphas=[0], Deltas=[0], tref=None, minStartTime=None,
                     maxStartTime=None, BSGL=False, minCoverFreq=None,
                     maxCoverFreq=None, earth_ephem=None, sun_ephem=None,
                     detectors=None):
            """ Set up the output location and bookkeeping for a grid search

            Parameters
            ----------
            label, outdir: str
                A label and directory to read/write data from/to
            sftfilepath: str
                File pattern to match SFTs
            F0s, F1s, F2s, Alphas, Deltas: tuple
                Length 3 tuple describing the grid for each parameter, e.g
                [F0min, F0max, dF0], for a fixed value simply give [F0].
            tref, minStartTime, maxStartTime: int
                GPS seconds of the reference time, start time and end time

            For all other parameters, see `pyfstat.ComputeFStat` for details

            Notes
            -----
            Attributes such as `self.outdir` and `self.label` are read below
            but not assigned in this body; presumably the
            `helper_functions.initializer` decorator copies the arguments
            onto `self` before the body runs -- confirm against that helper.
            The list defaults are only read here, never modified in place.
            """

            if earth_ephem is None:
                self.earth_ephem = self.earth_ephem_default
            if sun_ephem is None:
                self.sun_ephem = self.sun_ephem_default

            # makedirs (rather than mkdir) so that a nested output path whose
            # parent directories do not exist yet can still be created
            if not os.path.isdir(outdir):
                os.makedirs(outdir)
            self.out_file = '{}/{}_gridFS.txt'.format(self.outdir, self.label)
            # Column names of the stored data; the two '_' entries stand for
            # the leading minStartTime/maxStartTime columns
            self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta']
    
        def inititate_search_object(self):
            """ Create the ComputeFstat instance used to evaluate grid points

            The instance is stored as `self.search`. It is configured from
            the data-selection, frequency-cover, ephemeris, detector and BSGL
            attributes of this object, always with `transient=False`.
            """
            logging.info('Setting up search object')
            search_kwargs = dict(
                tref=self.tref,
                sftfilepath=self.sftfilepath,
                minCoverFreq=self.minCoverFreq,
                maxCoverFreq=self.maxCoverFreq,
                earth_ephem=self.earth_ephem,
                sun_ephem=self.sun_ephem,
                detectors=self.detectors,
                transient=False,
                minStartTime=self.minStartTime,
                maxStartTime=self.maxStartTime,
                BSGL=self.BSGL,
            )
            self.search = ComputeFstat(**search_kwargs)
    
        def get_array_from_tuple(self, x):
            """ Convert a grid specification into an array of grid points

            Parameters
            ----------
            x: sequence
                Either a length 1 sequence `[value]` for a fixed value, or
                `[min, max, step]` describing a uniformly spaced grid.

            Returns
            -------
            np.ndarray
                The single value, or the points `min, min+step, ...`. The
                upper bound is padded by a relative 1e-15 so that `max`
                itself is kept when it falls on the grid.
            """
            if len(x) == 1:
                return np.array(x)

            lower, upper, step = x[0], x[1], x[2]
            # np.arange excludes its stop value, so the stop is nudged just
            # past `upper`. Multiplying a negative bound by (1+1e-15) would
            # move the stop further from zero, i.e. below `upper`, and drop
            # the end point; flip the sign of the padding in that case.
            # A bound of exactly 0 receives no padding.
            pad = 1e-15 if upper >= 0 else -1e-15
            return np.arange(lower, upper*(1+pad), step)
    
        def get_input_data_array(self):
            arrays = []
            for tup in ([self.minStartTime], [self.maxStartTime], self.F0s, self.F1s, self.F2s,
                        self.Alphas, self.Deltas):
                arrays.append(self.get_array_from_tuple(tup))
    
            input_data = []
            for vals in itertools.product(*arrays):
                input_data.append(vals)
    
            self.arrays = arrays
            self.input_data = np.array(input_data)
    
        def check_old_data_is_okay_to_use(self):
            """ Decide whether previously saved results can be reused

            Returns
            -------
            np.ndarray or False
                The 2D array loaded from `self.out_file` when all of its
                columns except the last equal `self.input_data` exactly.
                False when `args.clean` is set, when the file does not
                exist, or when the stored input differs from the current one.
            """
            if args.clean:
                return False
            if os.path.isfile(self.out_file) is False:
                logging.info('No old data found, continuing with grid search')
                return False

            data = np.atleast_2d(np.genfromtxt(self.out_file, delimiter=' '))
            old_input = data[:, 0:-1]
            # Compare shapes before values: an elementwise `==` between
            # arrays of incompatible shapes is an error on recent numpy,
            # whereas a grid of a different size just means "input differs"
            same_grid = (old_input.shape == self.input_data.shape
                         and np.all(old_input == self.input_data))
            if same_grid:
                logging.info(
                    'Old data found with matching input, no search performed')
                return data

            logging.info(
                'Old data found, input differs, continuing with grid search')
            return False
    
        def run(self, return_data=False):
            """ Evaluate the search at every grid point, reusing old results

            Parameters
            ----------
            return_data: bool
                If True, return the results array instead of saving it to
                `self.out_file`.

            Returns
            -------
            np.ndarray or None
                With `return_data=True`, an array with one row per grid
                point holding the input values followed by the value
                returned by the search object. Otherwise None; the results
                are then available as `self.data`.
            """
            self.get_input_data_array()
            old_data = self.check_old_data_is_okay_to_use()
            if old_data is not False:
                self.data = old_data
                # Honour return_data on the cached path as well; previously
                # the caller received None whenever old results were reused
                return old_data if return_data else None

            self.inititate_search_object()

            logging.info('Total number of grid points is {}'.format(
                len(self.input_data)))

            data = []
            for vals in tqdm(self.input_data):
                FS = self.search.run_computefstatistic_single_point(*vals)
                data.append(list(vals) + [FS])

            data = np.array(data)
            if return_data:
                return data
            else:
                logging.info('Saving data to {}'.format(self.out_file))
                np.savetxt(self.out_file, data, delimiter=' ')
                self.data = data
    
        def convert_F0_to_mismatch(self, F0, F0hat, Tseg):
            """ Map a uniformly spaced F0 grid onto a mismatch axis

            The spacing is taken from the first two entries of `F0` and
            converted as (pi*Tseg*dF0)**2/12. The returned array is centred
            on zero and spans len(F0) such steps. `F0hat` is accepted but
            not used.
            """
            step = F0[1] - F0[0]
            mismatch_step = (np.pi*Tseg*step)**2 / 12.
            half_span = len(F0)*mismatch_step/2.
            return np.arange(-half_span, half_span, mismatch_step)
    
        def convert_F1_to_mismatch(self, F1, F1hat, Tseg):
            """ Map a uniformly spaced F1 grid onto a mismatch axis

            The spacing is taken from the first two entries of `F1` and
            converted as (pi*Tseg**2*dF1)**2/720. The returned array is
            centred on zero and spans len(F1) such steps. `F1hat` is
            accepted but not used.
            """
            step = F1[1] - F1[0]
            mismatch_step = (np.pi*Tseg**2*step)**2 / 720.
            half_span = len(F1)*mismatch_step/2.
            return np.arange(-half_span, half_span, mismatch_step)
    
        def add_mismatch_to_ax(self, ax, x, y, xkey, ykey, xhat, yhat, Tseg):
            """ Attach secondary axes carrying mismatch limits to `ax`

            A twin x-axis and a twin y-axis are always created, both with
            zorder -10. Their limits are only set when `xkey` is 'F0'
            (from `convert_F0_to_mismatch`) and when `ykey` is 'F1'
            (from `convert_F1_to_mismatch`) respectively.
            """
            top_ax = ax.twiny()
            top_ax.zorder = -10
            right_ax = ax.twinx()
            right_ax.zorder = -10

            if xkey == 'F0':
                x_mismatch = self.convert_F0_to_mismatch(x, xhat, Tseg)
                top_ax.set_xlim(x_mismatch[0], x_mismatch[-1])

            if ykey == 'F1':
                y_mismatch = self.convert_F1_to_mismatch(y, yhat, Tseg)
                right_ax.set_ylim(y_mismatch[0], y_mismatch[-1])
    
        def plot_1D(self, xkey):
            """ Plot the last data column against one grid parameter

            The unique values of the `xkey` column of `self.data` form the
            x-axis; the figure is written to '<outdir>/<label>_1D.png'.
            """
            fig, ax = plt.subplots()
            column = self.keys.index(xkey)
            grid_values = np.unique(self.data[:, column])
            twoF = self.data[:, -1]
            plt.plot(grid_values, twoF)
            target = '{}/{}_1D.png'.format(self.outdir, self.label)
            fig.savefig(target)
    
        def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None,
                    add_mismatch=None, xN=None, yN=None, flat_keys=[],
                    rel_flat_idxs=[], flatten_method=np.max, title=None,
                    predicted_twoF=None, cm=None, cbarkwargs={}):
            """ Plots a 2D grid of 2F values

            Parameters
            ----------
            xkey, ykey: str
                Entries of `self.keys` selecting the data columns shown on
                the x and y axes
            ax: matplotlib axes, optional
                Axes to draw on; if None a new figure and axes are created
            save: bool
                If True, save the figure to '<outdir>/<label>_2D.png' and
                return None; otherwise return the axes
            vmin, vmax: float, optional
                Passed to `pcolormesh`
            add_mismatch: tuple (xhat, yhat, Tseg)
                If not None, add a secondary axis with the metric mismatch
                from the point xhat, yhat with duration Tseg
            xN, yN: int, optional
                If given, passed to `MaxNLocator` for the major ticks of the
                x and y axes
            flat_keys: list of str
                Further entries of `self.keys`; the number of unique values
                in each adds a trailing dimension when reshaping the data
            rel_flat_idxs: list of int
                Axes of the reshaped array reduced with `flatten_method`
            flatten_method: np.max
                Function to use in flattening the flat_keys
            title: str, optional
                Axes title
            predicted_twoF: float, optional
                If truthy, plot (predicted_twoF - 2F)/(predicted_twoF + 4)
                instead of 2F, by default with the reversed colormap
            cm: colormap, optional
                Defaults to viridis (viridis_r when `predicted_twoF` is set)
            cbarkwargs: dict
                Extra keyword arguments for `plt.colorbar`

            The list/dict defaults are only read here, never modified.
            """
            if ax is None:
                fig, ax = plt.subplots()
            else:
                # `fig` is needed below when save=True; without this a
                # caller-supplied axes combined with save=True hit an
                # unbound local name
                fig = ax.get_figure()
            xidx = self.keys.index(xkey)
            yidx = self.keys.index(ykey)
            flat_idxs = [self.keys.index(k) for k in flat_keys]

            x = np.unique(self.data[:, xidx])
            y = np.unique(self.data[:, yidx])
            flat_vals = [np.unique(self.data[:, j]) for j in flat_idxs]
            z = self.data[:, -1]

            Y, X = np.meshgrid(y, x)
            shape = [len(x), len(y)] + [len(v) for v in flat_vals]
            Z = z.reshape(shape)

            if len(rel_flat_idxs) > 0:
                Z = flatten_method(Z, axis=tuple(rel_flat_idxs))

            if predicted_twoF:
                Z = (predicted_twoF - Z) / (predicted_twoF + 4)
                if cm is None:
                    cm = plt.cm.viridis_r
            else:
                if cm is None:
                    cm = plt.cm.viridis

            pax = ax.pcolormesh(X, Y, Z, cmap=cm, vmin=vmin, vmax=vmax)
            cb = plt.colorbar(pax, ax=ax, **cbarkwargs)
            # Raw strings: same text as before, without relying on Python
            # passing through the unrecognised escapes
            cb.set_label(r'$2\mathcal{F}$')

            if add_mismatch:
                self.add_mismatch_to_ax(ax, x, y, xkey, ykey, *add_mismatch)

            ax.set_xlim(x[0], x[-1])
            ax.set_ylim(y[0], y[-1])
            labels = {'F0': r'$f$', 'F1': r'$\dot{f}$'}
            # Fall back to the bare key name for parameters without a
            # dedicated label instead of raising KeyError
            ax.set_xlabel(labels.get(xkey, xkey))
            ax.set_ylabel(labels.get(ykey, ykey))

            if title:
                ax.set_title(title)

            if xN:
                ax.xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(xN))
            if yN:
                ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(yN))

            if save:
                fig.tight_layout()
                fig.savefig('{}/{}_2D.png'.format(self.outdir, self.label))
            else:
                return ax
    
        def get_max_twoF(self):
            twoF = self.data[:, -1]
            idx = np.argmax(twoF)
            v = self.data[idx, :]
            d = OrderedDict(minStartTime=v[0], maxStartTime=v[1], F0=v[2], F1=v[3],
                            F2=v[4], Alpha=v[5], Delta=v[6], twoF=v[7])
            return d
    
        def print_max_twoF(self):
            d = self.get_max_twoF()
            print('Max twoF values for {}:'.format(self.label))
            for k, v in d.iteritems():
                print('  {}={}'.format(k, v))
    
    
    class GridUniformPriorSearch():
        """ Run and plot an F0-F1 grid search spanning a prior range """
        def __init__(self, theta_prior, NF0, NF1, label, outdir, sftfilepath,
                     tref, minStartTime, maxStartTime, BSGL=False, detectors=None,
                     **kwargs):
            """
            Parameters
            ----------
            theta_prior: dict
                Must provide theta_prior['F0'] and theta_prior['F1'] as
                dicts with 'lower' and 'upper' entries, and fixed values
                theta_prior['Alpha'] and theta_prior['Delta']
            NF0, NF1: int
                Number of grid steps across the F0 and F1 ranges
            label, outdir, sftfilepath, tref, minStartTime, maxStartTime,
            BSGL, detectors:
                Passed on to `GridSearch`
            **kwargs:
                Passed on to `GridSearch.plot_2D`

            The search is run and plotted on construction; the `GridSearch`
            instance is kept as `self.search` so its results stay reachable.
            """
            F0_lower = theta_prior['F0']['lower']
            F0_upper = theta_prior['F0']['upper']
            F1_lower = theta_prior['F1']['lower']
            F1_upper = theta_prior['F1']['upper']
            # float() keeps the step a true quotient on Python 2 even when
            # the bounds and the step count are all integers
            dF0 = (F0_upper - F0_lower)/float(NF0)
            dF1 = (F1_upper - F1_lower)/float(NF1)
            F0s = [F0_lower, F0_upper, dF0]
            F1s = [F1_lower, F1_upper, dF1]
            self.search = GridSearch(
                label, outdir, sftfilepath, F0s=F0s, F1s=F1s, tref=tref,
                Alphas=[theta_prior['Alpha']], Deltas=[theta_prior['Delta']],
                minStartTime=minStartTime, maxStartTime=maxStartTime, BSGL=BSGL,
                detectors=detectors)
            self.search.run()
            self.search.plot_2D('F0', 'F1', **kwargs)
    
    
    class GridGlitchSearch(GridSearch):
        """ Grid search using the SemiCoherentGlitchSearch """
        @helper_functions.initializer
        def __init__(self, label, outdir, sftfilepath=None, F0s=[0],
                     F1s=[0], F2s=[0], delta_F0s=[0], delta_F1s=[0], tglitchs=None,
                     Alphas=[0], Deltas=[0], tref=None, minStartTime=None,
                     maxStartTime=None, minCoverFreq=None, maxCoverFreq=None,
                     write_after=1000, earth_ephem=None, sun_ephem=None,
                     BSGL=False):
            """
            Parameters
            ----------
            label, outdir: str
                A label and directory to read/write data from/to
            sftfilepath: str
                File pattern to match SFTs
            F0s, F1s, F2s, delta_F0s, delta_F1s, tglitchs, Alphas, Deltas: tuple
                Length 3 tuple describing the grid for each parameter, e.g
                [F0min, F0max, dF0], for a fixed value simply give [F0].
                If tglitchs is None, [maxStartTime] is used.
            tref, minStartTime, maxStartTime: int
                GPS seconds of the reference time, start time and end time
            write_after: int
                Accepted but not used in this body
            BSGL: bool
                Passed to SemiCoherentGlitchSearch. Added as a trailing
                keyword because the body previously read `self.BSGL`
                although no such argument existed in the signature.

            For all other parameters, see pyfstat.ComputeFStat.

            Attributes such as `self.outdir`, `self.label` and
            `self.sftfilepath` are read below but not assigned in this body;
            presumably the `helper_functions.initializer` decorator copies
            the arguments onto `self` -- confirm against that helper.
            """
            if tglitchs is None:
                self.tglitchs = [self.maxStartTime]
            if earth_ephem is None:
                self.earth_ephem = self.earth_ephem_default
            if sun_ephem is None:
                self.sun_ephem = self.sun_ephem_default

            self.search = SemiCoherentGlitchSearch(
                label=label, outdir=outdir, sftfilepath=self.sftfilepath,
                tref=tref, minStartTime=minStartTime, maxStartTime=maxStartTime,
                minCoverFreq=minCoverFreq, maxCoverFreq=maxCoverFreq,
                earth_ephem=self.earth_ephem, sun_ephem=self.sun_ephem,
                BSGL=BSGL)

            # makedirs (rather than mkdir) so that a nested output path whose
            # parent directories do not exist yet can still be created
            if not os.path.isdir(outdir):
                os.makedirs(outdir)
            self.out_file = '{}/{}_gridFS.txt'.format(self.outdir, self.label)
            self.keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta', 'delta_F0',
                         'delta_F1', 'tglitch']
    
        def get_input_data_array(self):
            """ Build the table of grid points for the glitch search

            The F0s, F1s, F2s, Alphas, Deltas, delta_F0s, delta_F1s and
            tglitchs specifications are each expanded with
            `get_array_from_tuple`. The per-parameter arrays are stored in
            `self.arrays` and their Cartesian product, one row per grid
            point, in `self.input_data`.
            """
            specs = (self.F0s, self.F1s, self.F2s, self.Alphas, self.Deltas,
                     self.delta_F0s, self.delta_F1s, self.tglitchs)
            axes = [self.get_array_from_tuple(spec) for spec in specs]

            self.arrays = axes
            self.input_data = np.array(list(itertools.product(*axes)))