From 8435f54d19002576ee5174cff7b7310f9eaf3f98 Mon Sep 17 00:00:00 2001 From: Gregory Ashton <gregory.ashton@ligo.org> Date: Wed, 21 Sep 2016 18:40:25 +0200 Subject: [PATCH] Splits up the MCMC classes This makes the MCMCGlitchSearch a subclass of the more general MCMCSearch --- pyfstat.py | 276 +++++++++++++++++++++++++++++++++++-------------- tests/tests.py | 12 +-- 2 files changed, 205 insertions(+), 83 deletions(-) diff --git a/pyfstat.py b/pyfstat.py index adeed2b..59ea288 100755 --- a/pyfstat.py +++ b/pyfstat.py @@ -346,15 +346,14 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat): return twoFsegA + twoFsegB -class MCMCGlitchSearch(BaseSearchClass): - """ MCMC search using the SemiCoherentGlitchSearch """ +class MCMCSearch(BaseSearchClass): + """ MCMC search using ComputeFstat""" @initializer def __init__(self, label, outdir, sftlabel, sftdir, theta_prior, tref, tstart, tend, nsteps=[100, 100, 100], nwalkers=100, ntemps=1, - nglitch=0, theta_initial=None, minCoverFreq=None, + theta_initial=None, minCoverFreq=None, maxCoverFreq=None, scatter_val=1e-4, betas=None, - detector=None, dtglitchmin=20*86400, earth_ephem=None, - sun_ephem=None): + detector=None, earth_ephem=None, sun_ephem=None): """ Parameters label, outdir: str @@ -370,8 +369,6 @@ class MCMCGlitchSearch(BaseSearchClass): Either a dictionary of distribution about which to distribute the initial walkers about, an array (from which the walkers will be scattered by scatter_val, or None in which case the prior is used. - nglitch: int - The number of glitches to allow tref, tstart, tend: int GPS seconds of the reference time, start time and end time nsteps: list (m,) @@ -379,9 +376,6 @@ class MCMCGlitchSearch(BaseSearchClass): give the nburn and nprod of the 'production' run, all entries before are for iterative initialisation steps (usually just one) e.g. [1000, 1000, 500]. - dtglitchmin: int - The minimum duration (in seconds) of a segment between two glitches - or a glitch and the start/end of the data nwalkers, ntemps: int Number of walkers and temperatures minCoverFreq, maxCoverFreq: float @@ -394,12 +388,14 @@ class MCMCGlitchSearch(BaseSearchClass): """ - logging.info(('Set-up MCMC search with {} glitches for model {} on' - ' data {}').format(self.nglitch, self.label, - self.sftlabel)) + logging.info( + 'Set-up MCMC search for model {} on data {}'.format( + self.label, self.sftlabel)) if os.path.isdir(outdir) is False: os.mkdir(outdir) self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label) + self.theta_prior['tstart'] = self.tstart + self.theta_prior['tend'] = self.tend self.unpack_input_theta() self.ndim = len(self.theta_keys) self.sft_filepath = self.sftdir+'/*_'+self.sftlabel+"*sft" @@ -415,65 +411,43 @@ class MCMCGlitchSearch(BaseSearchClass): def inititate_search_object(self): logging.info('Setting up search object') - self.search = SemiCoherentGlitchSearch( - label=self.label, outdir=self.outdir, sftlabel=self.sftlabel, - sftdir=self.sftdir, tref=self.tref, tstart=self.tstart, - tend=self.tend, minCoverFreq=self.minCoverFreq, + self.search = ComputeFstat( + tref=self.tref, sftlabel=self.sftlabel, + sftdir=self.sftdir, minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq, earth_ephem=self.earth_ephem, - sun_ephem=self.sun_ephem, detector=self.detector, - nglitch=self.nglitch) + sun_ephem=self.sun_ephem, detector=self.detector) def logp(self, theta_vals, theta_prior, theta_keys, search): - if self.nglitch > 1: - ts = [self.tstart] + theta_vals[-self.nglitch:] + [self.tend] - if np.array_equal(ts, np.sort(ts)) is False: - return -np.inf - if any(np.diff(ts) < self.dtglitchmin): - return -np.inf - - H = [self.Generic_lnprior(**theta_prior[key])(p) for p, key in + H = [self.generic_lnprior(**theta_prior[key])(p) for p, key in zip(theta_vals, theta_keys)] return np.sum(H) def logl(self, theta, search): for j, theta_i in enumerate(self.theta_idxs): self.fixed_theta[theta_i] = theta[j] - FS = search.compute_nglitch_fstat(*self.fixed_theta) + FS = search.run_computefstatistic_single_point(*self.fixed_theta) return FS def unpack_input_theta(self): - glitch_keys = ['delta_F0', 'delta_F1', 'tglitch'] - full_glitch_keys = list(np.array( - [[gk]*self.nglitch for gk in glitch_keys]).flatten()) - full_theta_keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta']+full_glitch_keys + full_theta_keys = ['tstart', 'tend', 'F0', 'F1', 'F2', 'Alpha', + 'Delta'] full_theta_keys_copy = copy.copy(full_theta_keys) - glitch_symbols = ['$\delta f$', '$\delta \dot{f}$', r'$t_{glitch}$'] - full_glitch_symbols = list(np.array( - [[gs]*self.nglitch for gs in glitch_symbols]).flatten()) - full_theta_symbols = (['$f$', '$\dot{f}$', '$\ddot{f}$', r'$\alpha$', - r'$\delta$'] + full_glitch_symbols) + full_theta_symbols = ['_', '_', '$f$', '$\dot{f}$', '$\ddot{f}$', + r'$\alpha$', r'$\delta$'] self.theta_keys = [] fixed_theta_dict = {} for key, val in self.theta_prior.iteritems(): if type(val) is dict: fixed_theta_dict[key] = 0 - if key in glitch_keys: - for i in range(self.nglitch): - self.theta_keys.append(key) - else: - self.theta_keys.append(key) + self.theta_keys.append(key) elif type(val) in [float, int, np.float64]: fixed_theta_dict[key] = val else: raise ValueError( 'Type {} of {} in theta not recognised'.format( type(val), key)) - if key in glitch_keys: - for i in range(self.nglitch): - full_theta_keys_copy.pop(full_theta_keys_copy.index(key)) - else: - full_theta_keys_copy.pop(full_theta_keys_copy.index(key)) + full_theta_keys_copy.pop(full_theta_keys_copy.index(key)) if len(full_theta_keys_copy) > 0: raise ValueError(('Input dictionary `theta` is missing the' @@ -489,13 +463,6 @@ class MCMCGlitchSearch(BaseSearchClass): self.theta_symbols = [self.theta_symbols[i] for i in idxs] self.theta_keys = [self.theta_keys[i] for i in idxs] - # Correct for number of glitches in the idxs - self.theta_idxs = np.array(self.theta_idxs) - while np.sum(self.theta_idxs[:-1] == self.theta_idxs[1:]) > 0: - for i, idx in enumerate(self.theta_idxs): - if idx in self.theta_idxs[:i]: - self.theta_idxs[i] += 1 - def check_initial_points(self, p0): initial_priors = np.array([ self.logp(p, self.theta_prior, self.theta_keys, self.search) @@ -525,7 +492,8 @@ class MCMCGlitchSearch(BaseSearchClass): logpargs=(self.theta_prior, self.theta_keys, self.search), loglargs=(self.search,), betas=self.betas) - p0 = self.GenerateInitial() + p0 = self.generate_initial_p0() + p0 = self.apply_corrections_to_p0(p0) self.check_initial_points(p0) ninit_steps = len(self.nsteps) - 2 @@ -534,11 +502,12 @@ class MCMCGlitchSearch(BaseSearchClass): j, ninit_steps, n)) sampler.run_mcmc(p0, n) - fig, axes = self.PlotWalkers(sampler, symbols=self.theta_symbols) + fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols) fig.savefig('{}/{}_init_{}_walkers.png'.format( self.outdir, self.label, j)) p0 = self.get_new_p0(sampler, scatter_val=self.scatter_val) + p0 = self.apply_corrections_to_p0(p0) self.check_initial_points(p0) sampler.reset() @@ -548,7 +517,7 @@ class MCMCGlitchSearch(BaseSearchClass): nburn+nprod)) sampler.run_mcmc(p0, nburn+nprod) - fig, axes = self.PlotWalkers(sampler, symbols=self.theta_symbols) + fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols) fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label)) samples = sampler.chain[0, :, nburn:, :].reshape((-1, self.ndim)) @@ -622,14 +591,14 @@ class MCMCGlitchSearch(BaseSearchClass): ax = axes[i][i] xlim = ax.get_xlim() s = samples[:, i] - prior = self.Generic_lnprior(**self.theta_prior[key]) + prior = self.generic_lnprior(**self.theta_prior[key]) x = np.linspace(s.min(), s.max(), 100) ax2 = ax.twinx() ax2.get_yaxis().set_visible(False) ax2.plot(x, [prior(xi) for xi in x], '-r') ax.set_xlim(xlim) - def Generic_lnprior(self, **kwargs): + def generic_lnprior(self, **kwargs): """ Return a lambda function of the pdf Parameters @@ -679,7 +648,7 @@ class MCMCGlitchSearch(BaseSearchClass): logging.info("kwargs:", kwargs) raise ValueError("Print unrecognise distribution") - def GenerateRV(self, **kwargs): + def generate_rv(self, **kwargs): dist_type = kwargs.pop('type') if dist_type == "unif": return np.random.uniform(low=kwargs['lower'], high=kwargs['upper']) @@ -694,8 +663,8 @@ class MCMCGlitchSearch(BaseSearchClass): else: raise ValueError("dist_type {} unknown".format(dist_type)) - def PlotWalkers(self, sampler, symbols=None, alpha=0.4, color="k", temp=0, - start=None, stop=None, draw_vline=None): + def plot_walkers(self, sampler, symbols=None, alpha=0.4, color="k", temp=0, + start=None, stop=None, draw_vline=None): """ Plot all the chains from a sampler """ shape = sampler.chain.shape @@ -725,38 +694,35 @@ class MCMCGlitchSearch(BaseSearchClass): return fig, axes - def _generate_scattered_p0(self, p): + def apply_corrections_to_p0(self, p0): + """ Apply any correction to the initial p0 values """ + return p0 + + def generate_scattered_p0(self, p): """ Generate a set of p0s scattered about p """ - p0 = [[p + scatter_val * p * np.random.randn(self.ndim) + p0 = [[p + self.scatter_val * p * np.random.randn(self.ndim) for i in xrange(self.nwalkers)] for j in xrange(self.ntemps)] return p0 - def _sort_p0_times(self, p0): - p0 = np.array(p0) - p0[:, :, -self.nglitch:] = np.sort(p0[:, :, -self.nglitch:], axis=2) - return p0 - - def GenerateInitial(self): + def generate_initial_p0(self): """ Generate a set of init vals for the walkers """ if type(self.theta_initial) == dict: - p0 = [[[self.GenerateRV(**self.theta_initial[key]) + p0 = [[[self.generate_rv(**self.theta_initial[key]) for key in self.theta_keys] for i in range(self.nwalkers)] for j in range(self.ntemps)] elif self.theta_initial is None: - p0 = [[[self.GenerateRV(**self.theta_prior[key]) + p0 = [[[self.generate_rv(**self.theta_prior[key]) for key in self.theta_keys] for i in range(self.nwalkers)] for j in range(self.ntemps)] elif len(self.theta_initial) == self.ndim: - p0 = self._generate_scattered_p0(self.theta_initial) + p0 = self.generate_scattered_p0(self.theta_initial) else: raise ValueError('theta_initial not understood') - if self.nglitch > 1: - p0 = self._sort_p0_times(p0) return p0 def get_new_p0(self, sampler, scatter_val=1e-3): @@ -780,8 +746,6 @@ class MCMCGlitchSearch(BaseSearchClass): p = pF[np.nanargmax(lnp)] p0 = self._generate_scattered_p0(p) - if self.nglitch > 1: - p0 = self._sort_p0_times(p0) return p0 def get_save_data_dictionary(self): @@ -923,6 +887,164 @@ class MCMCGlitchSearch(BaseSearchClass): k, d[k], d[k+'_std'])) +class MCMCGlitchSearch(MCMCSearch): + """ MCMC search using the SemiCoherentGlitchSearch """ + @initializer + def __init__(self, label, outdir, sftlabel, sftdir, theta_prior, tref, + tstart, tend, nsteps=[100, 100, 100], nwalkers=100, ntemps=1, + nglitch=0, theta_initial=None, minCoverFreq=None, + maxCoverFreq=None, scatter_val=1e-4, betas=None, + detector=None, dtglitchmin=20*86400, earth_ephem=None, + sun_ephem=None): + """ + Parameters + label, outdir: str + A label and directory to read/write data from/to + sftlabel, sftdir: str + A label and directory in which to find the relevant sft file + theta_prior: dict + Dictionary of priors and fixed values for the search parameters. + For each parameters (key of the dict), if it is to be held fixed + the value should be the constant float, if it is be searched, the + value should be a dictionary of the prior. + theta_initial: dict, array, (None) + Either a dictionary of distribution about which to distribute the + initial walkers about, an array (from which the walkers will be + scattered by scatter_val, or None in which case the prior is used. + nglitch: int + The number of glitches to allow + tref, tstart, tend: int + GPS seconds of the reference time, start time and end time + nsteps: list (m,) + List specifying the number of steps to take, the last two entries + give the nburn and nprod of the 'production' run, all entries + before are for iterative initialisation steps (usually just one) + e.g. [1000, 1000, 500]. + dtglitchmin: int + The minimum duration (in seconds) of a segment between two glitches + or a glitch and the start/end of the data + nwalkers, ntemps: int + Number of walkers and temperatures + minCoverFreq, maxCoverFreq: float + Minimum and maximum instantaneous frequency which will be covered + over the SFT time span as passed to CreateFstatInput + earth_ephem, sun_ephem: str + Paths of the two files containing positions of Earth and Sun, + respectively at evenly spaced times, as passed to CreateFstatInput + If None defaults defined in BaseSearchClass will be used + + """ + + logging.info(('Set-up MCMC glitch search with {} glitches for model {}' + ' on data {}').format(self.nglitch, self.label, + self.sftlabel)) + if os.path.isdir(outdir) is False: + os.mkdir(outdir) + self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label) + self.unpack_input_theta() + self.ndim = len(self.theta_keys) + self.sft_filepath = self.sftdir+'/*_'+self.sftlabel+"*sft" + if earth_ephem is None: + self.earth_ephem = self.earth_ephem_default + if sun_ephem is None: + self.sun_ephem = self.sun_ephem_default + + if args.clean and os.path.isfile(self.pickle_path): + os.rename(self.pickle_path, self.pickle_path+".old") + + self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use() + + def inititate_search_object(self): + logging.info('Setting up search object') + self.search = SemiCoherentGlitchSearch( + label=self.label, outdir=self.outdir, sftlabel=self.sftlabel, + sftdir=self.sftdir, tref=self.tref, tstart=self.tstart, + tend=self.tend, minCoverFreq=self.minCoverFreq, + maxCoverFreq=self.maxCoverFreq, earth_ephem=self.earth_ephem, + sun_ephem=self.sun_ephem, detector=self.detector, + nglitch=self.nglitch) + + def logp(self, theta_vals, theta_prior, theta_keys, search): + if self.nglitch > 1: + ts = [self.tstart] + theta_vals[-self.nglitch:] + [self.tend] + if np.array_equal(ts, np.sort(ts)) is False: + return -np.inf + if any(np.diff(ts) < self.dtglitchmin): + return -np.inf + + H = [self.generic_lnprior(**theta_prior[key])(p) for p, key in + zip(theta_vals, theta_keys)] + return np.sum(H) + + def logl(self, theta, search): + for j, theta_i in enumerate(self.theta_idxs): + self.fixed_theta[theta_i] = theta[j] + FS = search.compute_nglitch_fstat(*self.fixed_theta) + return FS + + def unpack_input_theta(self): + glitch_keys = ['delta_F0', 'delta_F1', 'tglitch'] + full_glitch_keys = list(np.array( + [[gk]*self.nglitch for gk in glitch_keys]).flatten()) + full_theta_keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta']+full_glitch_keys + full_theta_keys_copy = copy.copy(full_theta_keys) + + glitch_symbols = ['$\delta f$', '$\delta \dot{f}$', r'$t_{glitch}$'] + full_glitch_symbols = list(np.array( + [[gs]*self.nglitch for gs in glitch_symbols]).flatten()) + full_theta_symbols = (['$f$', '$\dot{f}$', '$\ddot{f}$', r'$\alpha$', + r'$\delta$'] + full_glitch_symbols) + self.theta_keys = [] + fixed_theta_dict = {} + for key, val in self.theta_prior.iteritems(): + if type(val) is dict: + fixed_theta_dict[key] = 0 + if key in glitch_keys: + for i in range(self.nglitch): + self.theta_keys.append(key) + else: + self.theta_keys.append(key) + elif type(val) in [float, int, np.float64]: + fixed_theta_dict[key] = val + else: + raise ValueError( + 'Type {} of {} in theta not recognised'.format( + type(val), key)) + if key in glitch_keys: + for i in range(self.nglitch): + full_theta_keys_copy.pop(full_theta_keys_copy.index(key)) + else: + full_theta_keys_copy.pop(full_theta_keys_copy.index(key)) + + if len(full_theta_keys_copy) > 0: + raise ValueError(('Input dictionary `theta` is missing the' + 'following keys: {}').format( + full_theta_keys_copy)) + + self.fixed_theta = [fixed_theta_dict[key] for key in full_theta_keys] + self.theta_idxs = [full_theta_keys.index(k) for k in self.theta_keys] + self.theta_symbols = [full_theta_symbols[i] for i in self.theta_idxs] + + idxs = np.argsort(self.theta_idxs) + self.theta_idxs = [self.theta_idxs[i] for i in idxs] + self.theta_symbols = [self.theta_symbols[i] for i in idxs] + self.theta_keys = [self.theta_keys[i] for i in idxs] + + # Correct for number of glitches in the idxs + self.theta_idxs = np.array(self.theta_idxs) + while np.sum(self.theta_idxs[:-1] == self.theta_idxs[1:]) > 0: + for i, idx in enumerate(self.theta_idxs): + if idx in self.theta_idxs[:i]: + self.theta_idxs[i] += 1 + + def apply_corrections_to_p0(self, p0): + p0 = np.array(p0) + if self.nglitch > 1: + p0[:, :, -self.nglitch:] = np.sort(p0[:, :, -self.nglitch:], + axis=2) + return p0 + + class GridGlitchSearch(BaseSearchClass): """ Gridded search using the SemiCoherentGlitchSearch """ @initializer diff --git a/tests/tests.py b/tests/tests.py index 98a3c3c..176b75f 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -137,7 +137,7 @@ class TestSemiCoherentGlitchSearch(unittest.TestCase): self.assertTrue(np.abs((FS - predicted_FS))/predicted_FS < 0.3) -class TestMCMCGlitchSearch(unittest.TestCase): +class TestMCMCSearch(unittest.TestCase): label = "MCMCTest" outdir = 'TestData' @@ -165,13 +165,12 @@ class TestMCMCGlitchSearch(unittest.TestCase): Writer.make_data() predicted_FS = Writer.predict_fstat() - theta = {'delta_F0': 0, 'delta_F1': 0, 'tglitch': tend, - 'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-9*F0)}, + theta = {'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-9*F0)}, 'F1': {'type': 'norm', 'loc': F1, 'scale': np.abs(1e-9*F1)}, 'F2': F2, 'Alpha': Alpha, 'Delta': Delta} - search = pyfstat.MCMCGlitchSearch( - label=self.label, outdir=self.outdir, theta=theta, tref=tref, + search = pyfstat.MCMCSearch( + label=self.label, outdir=self.outdir, theta_prior=theta, tref=tref, sftlabel=self.label, sftdir=self.outdir, tstart=tstart, tend=tend, nsteps=[100, 100], nwalkers=100, ntemps=1) @@ -181,7 +180,8 @@ class TestMCMCGlitchSearch(unittest.TestCase): print('Predicted twoF is {} while recovered is {}'.format( predicted_FS, FS)) - self.assertTrue(np.abs((FS - predicted_FS))/predicted_FS < 0.3) + self.assertTrue( + FS > predicted_FS or np.abs((FS-predicted_FS))/predicted_FS < 0.3) if __name__ == '__main__': -- GitLab