diff --git a/pyfstat/__init__.py b/pyfstat/__init__.py
index d1c4da564ee45d48d0915b6f32c78e553a1c5897..139e97b09a13c6d072975a5e0223dfc824e96d81 100644
--- a/pyfstat/__init__.py
+++ b/pyfstat/__init__.py
@@ -1,7 +1,6 @@
-from __future__ import division
+from __future__ import division as _division

-from .core import BaseSearchClass, ComputeFstat, Writer
-from .mcmc_based_searches import *
-from .grid_based_searches import *
-from .helper_functions import texify_float
+from .core import BaseSearchClass, ComputeFstat, Writer, SemiCoherentSearch, SemiCoherentGlitchSearch
+from .mcmc_based_searches import MCMCSearch, MCMCGlitchSearch, MCMCSemiCoherentSearch, MCMCFollowUpSearch, MCMCTransientSearch
+from .grid_based_searches import GridSearch, GridUniformPriorSearch, GridGlitchSearch
diff --git a/pyfstat/core.py b/pyfstat/core.py
index a4784bf8043430d345a35f2ae10ad7e9dfd3a8c1..a8c713c402a35335051a893927627bec7b6ea615 100755
--- a/pyfstat/core.py
+++ b/pyfstat/core.py
@@ -37,7 +37,7 @@ class BaseSearchClass(object):
     earth_ephem_default = earth_ephem
     sun_ephem_default = sun_ephem

-    def add_log_file(self):
+    def _add_log_file(self):
         """ Log output to a file, requires class to have outdir and label """
         logfilename = '{}/{}.log'.format(self.outdir, self.label)
         fh = logging.FileHandler(logfilename)
@@ -47,7 +47,7 @@ class BaseSearchClass(object):
             datefmt='%y-%m-%d %H:%M'))
         logging.getLogger().addHandler(fh)

-    def shift_matrix(self, n, dT):
+    def _shift_matrix(self, n, dT):
         """ Generate the shift matrix

         Parameters
         ----------
@@ -78,7 +78,7 @@ class BaseSearchClass(object):
                     m[i, j] = float(dT)**(j-i) / factorial(j-i)
         return m

-    def shift_coefficients(self, theta, dT):
+    def _shift_coefficients(self, theta, dT):
         """ Shift a set of coefficients by dT

         Parameters
         ----------
@@ -96,30 +96,30 @@ class BaseSearchClass(object):
         """
         n = len(theta)
-        m = self.shift_matrix(n, dT)
+        m = self._shift_matrix(n, dT)
         return np.dot(m, theta)

-    def calculate_thetas(self, theta, delta_thetas, tbounds, theta0_idx=0):
+    def _calculate_thetas(self, theta, delta_thetas, tbounds, theta0_idx=0):
         """ Calculates the set of coefficients for the post-glitch signal """
         thetas = [theta]
         for i, dt in enumerate(delta_thetas):
             if i < theta0_idx:
-                pre_theta_at_ith_glitch = self.shift_coefficients(
+                pre_theta_at_ith_glitch = self._shift_coefficients(
                     thetas[0], tbounds[i+1] - self.tref)
                 post_theta_at_ith_glitch = pre_theta_at_ith_glitch - dt
-                thetas.insert(0, self.shift_coefficients(
+                thetas.insert(0, self._shift_coefficients(
                     post_theta_at_ith_glitch, self.tref - tbounds[i+1]))
             elif i >= theta0_idx:
-                pre_theta_at_ith_glitch = self.shift_coefficients(
+                pre_theta_at_ith_glitch = self._shift_coefficients(
                     thetas[i], tbounds[i+1] - self.tref)
                 post_theta_at_ith_glitch = pre_theta_at_ith_glitch + dt
-                thetas.append(self.shift_coefficients(
+                thetas.append(self._shift_coefficients(
                     post_theta_at_ith_glitch, self.tref - tbounds[i+1]))
         self.thetas_at_tref = thetas
         return thetas

-    def generate_loudest(self):
+    def _generate_loudest(self):
         params = read_par(self.label, self.outdir)
         for key in ['Alpha', 'Delta', 'F0', 'F1']:
             if key not in params:
@@ -133,7 +133,7 @@ class BaseSearchClass(object):
                self.maxStartTime)
         subprocess.call([cmd], shell=True)

-    def get_list_of_matching_sfts(self):
+    def _get_list_of_matching_sfts(self):
         matches = [glob.glob(p) for p in self.sftfilepath]
         matches = [item for sublist in matches for item in sublist]
         if len(matches) > 0:
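The `_shift_matrix` helper renamed above builds the upper-triangular Taylor matrix m[i, j] = dT**(j-i)/(j-i)! used to translate a set of frequency-evolution coefficients by a time offset dT (the test suite at the end of this patch shows that the real method also carries a 2*pi factor on the phase row). A minimal standalone sketch of the idea, assuming only numpy; `shift_matrix_sketch` is an illustrative name, not the class API:

    # Sketch of the Taylor-expansion shift used by BaseSearchClass:
    # shifting by dT and then by -dT recovers the original coefficients,
    # since M(dT) is the exponential of a nilpotent matrix.
    import numpy as np
    from math import factorial

    def shift_matrix_sketch(n, dT):
        m = np.zeros((n, n))
        for i in range(n):
            for j in range(i, n):
                m[i, j] = float(dT)**(j-i) / factorial(j-i)
        return m

    theta = np.array([100.0, -1e-9, 0.0, 0.0])  # e.g. coefficients [F0, F1, F2, F3]
    shifted = np.dot(shift_matrix_sketch(4, 10.0), theta)
    recovered = np.dot(shift_matrix_sketch(4, -10.0), shifted)
    assert np.allclose(theta, recovered)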
@@ -685,7 +685,7 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat):
         delta_thetas = np.atleast_2d(
             np.array([delta_phi, delta_F0s, delta_F1s, delta_F2]).T)

-        thetas = self.calculate_thetas(theta, delta_thetas, tboundaries,
+        thetas = self._calculate_thetas(theta, delta_thetas, tboundaries,
                                        theta0_idx=self.theta0_idx)

         twoFSum = 0
@@ -713,9 +713,9 @@ class SemiCoherentGlitchSearch(BaseSearchClass, ComputeFstat):
         delta_theta = [delta_F0, delta_F1, 0]
         tref = self.tref
-        theta_at_glitch = self.shift_coefficients(theta, tglitch - tref)
+        theta_at_glitch = self._shift_coefficients(theta, tglitch - tref)
         theta_post_glitch_at_glitch = theta_at_glitch + delta_theta
-        theta_post_glitch = self.shift_coefficients(
+        theta_post_glitch = self._shift_coefficients(
             theta_post_glitch_at_glitch, tref - tglitch)

         twoFsegA = self.run_computefstatistic_single_point(
@@ -849,7 +849,7 @@ transientTauDays={:1.3f}\n""")
         """

-        thetas = self.calculate_thetas(self.theta, self.delta_thetas,
+        thetas = self._calculate_thetas(self.theta, self.delta_thetas,
                                        self.tbounds)

         content = ''
diff --git a/pyfstat/helper_functions.py b/pyfstat/helper_functions.py
index c153915b218b1c5e82572b2d4c57dcaaf1d059f2..41944e0f436440511cf6ce7886b94f5ca5caa752 100644
--- a/pyfstat/helper_functions.py
+++ b/pyfstat/helper_functions.py
@@ -65,6 +65,7 @@ def set_up_command_line_arguments():


 def set_up_ephemeris_configuration():
+    """ Returns the earth_ephem and sun_ephem """
     config_file = os.path.expanduser('~')+'/.pyfstat.conf'
     if os.path.isfile(config_file):
         d = {}
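For reference, the helper documented above is the supported way to obtain the two ephemeris paths, and the updated tests at the end of this patch call it the same way; the format of `~/.pyfstat.conf` itself is not shown in this hunk:

    # Fetch the ephemeris paths configured in ~/.pyfstat.conf
    # (falls back to the package defaults if no config file exists).
    import pyfstat
    earth_ephem, sun_ephem = pyfstat.helper_functions.set_up_ephemeris_configuration()
    print(earth_ephem, sun_ephem)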
diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py
index 9cc7083f8a7cb7b500534aa10a283a9358fa11d5..aae72a68c641819edf675eb56a8531f5c5c444d3 100644
--- a/pyfstat/mcmc_based_searches.py
+++ b/pyfstat/mcmc_based_searches.py
@@ -44,7 +44,8 @@ class MCMCSearch(core.BaseSearchClass):
     label, outdir: str
         A label and directory to read/write data from/to
     sftfilepath: str
-        File patern to match SFTs
+        Pattern to match SFTs using wildcards (*?) and ranges [0-9];
+        multiple patterns can be given separated by colons.
     theta_prior: dict
         Dictionary of priors and fixed values for the search parameters.
         For each parameters (key of the dict), if it is to be held fixed
@@ -84,12 +85,12 @@ class MCMCSearch(core.BaseSearchClass):

         if os.path.isdir(outdir) is False:
             os.mkdir(outdir)
-        self.add_log_file()
+        self._add_log_file()
         logging.info(
             'Set-up MCMC search for model {} on data {}'.format(
                 self.label, self.sftfilepath))
         self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
-        self.unpack_input_theta()
+        self._unpack_input_theta()
         self.ndim = len(self.theta_keys)
         if self.log10temperature_min:
             self.betas = np.logspace(0, self.log10temperature_min, self.ntemps)
@@ -104,9 +105,9 @@ class MCMCSearch(core.BaseSearchClass):
         if args.clean and os.path.isfile(self.pickle_path):
             os.rename(self.pickle_path, self.pickle_path+".old")

-        self.log_input()
+        self._log_input()

-    def log_input(self):
+    def _log_input(self):
         logging.info('theta_prior = {}'.format(self.theta_prior))
         logging.info('nwalkers={}'.format(self.nwalkers))
         logging.info('scatter_val = {}'.format(self.scatter_val))
@@ -115,7 +116,7 @@ class MCMCSearch(core.BaseSearchClass):
         logging.info('log10temperature_min = {}'.format(
             self.log10temperature_min))

-    def initiate_search_object(self):
+    def _initiate_search_object(self):
         logging.info('Setting up search object')
         self.search = core.ComputeFstat(
             tref=self.tref, sftfilepath=self.sftfilepath,
@@ -127,7 +128,7 @@ class MCMCSearch(core.BaseSearchClass):
             assumeSqrtSX=self.assumeSqrtSX)

     def logp(self, theta_vals, theta_prior, theta_keys, search):
-        H = [self.generic_lnprior(**theta_prior[key])(p) for p, key in
+        H = [self._generic_lnprior(**theta_prior[key])(p) for p, key in
             zip(theta_vals, theta_keys)]
         return np.sum(H)
@@ -138,7 +139,7 @@ class MCMCSearch(core.BaseSearchClass):
                 *self.fixed_theta)
         return FS

-    def unpack_input_theta(self):
+    def _unpack_input_theta(self):
         full_theta_keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta']
         if self.binary:
             full_theta_keys += [
@@ -179,7 +180,7 @@ class MCMCSearch(core.BaseSearchClass):
             self.theta_symbols = [self.theta_symbols[i] for i in idxs]
             self.theta_keys = [self.theta_keys[i] for i in idxs]

-    def check_initial_points(self, p0):
+    def _check_initial_points(self, p0):
         for nt in range(self.ntemps):
             logging.info('Checking temperature {} chains'.format(nt))
             initial_priors = np.array([
@@ -193,10 +194,10 @@ class MCMCSearch(core.BaseSearchClass):
                     .format(len(initial_priors),
                             number_of_initial_out_of_bounds))

-                p0 = self.generate_new_p0_to_fix_initial_points(
+                p0 = self._generate_new_p0_to_fix_initial_points(
                     p0, nt, initial_priors)

-    def generate_new_p0_to_fix_initial_points(self, p0, nt, initial_priors):
+    def _generate_new_p0_to_fix_initial_points(self, p0, nt, initial_priors):
         logging.info('Attempting to correct intial values')
         idxs = np.arange(self.nwalkers)[initial_priors == -np.inf]
         count = 0
@@ -217,7 +218,7 @@ class MCMCSearch(core.BaseSearchClass):

         return p0

-    def OLD_run_sampler_with_progress_bar(self, sampler, ns, p0):
+    def _OLD_run_sampler_with_progress_bar(self, sampler, ns, p0):
         for result in tqdm(sampler.sample(p0, iterations=ns), total=ns):
             pass
         return sampler
@@ -271,7 +272,7 @@ class MCMCSearch(core.BaseSearchClass):
         self.convergence_number = 0
         self.convergence_plot_upper_lim = convergence_plot_upper_lim

-    def get_convergence_statistic(self, i, sampler):
+    def _get_convergence_statistic(self, i, sampler):
         s = sampler.chain[0, :, i-self.convergence_length+1:i+1, :]
         within_std = np.mean(np.var(s, axis=1), axis=0)
         per_walker_mean = np.mean(s, axis=1)
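To make the new sftfilepath wording concrete, a hypothetical set-up sketch; only `label`, `outdir`, `sftfilepath` and `theta_prior` are confirmed by this diff, so treat the remaining keyword names and values as assumptions:

    # Fixed values are given directly; sampled parameters get a prior dict.
    import pyfstat

    theta_prior = {
        'F0': {'type': 'unif', 'lower': 29.9, 'upper': 30.1},  # sampled
        'F1': -1e-10,                                          # held fixed
        'F2': 0, 'Alpha': 1.0, 'Delta': 0.5,
    }
    search = pyfstat.MCMCSearch(
        label='demo', outdir='demo_out',
        sftfilepath='data/H-*.sft:data/L-*.sft',  # colon-separated patterns
        theta_prior=theta_prior, tref=1000000000,
        nsteps=[100, 100], nwalkers=100, ntemps=1)
    search.run()
    search.print_summary()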
@@ -286,25 +287,25 @@ class MCMCSearch(core.BaseSearchClass):
             self.convergence_diagnosticx.append(i - self.convergence_length/2)
         return c

-    def burnin_convergence_test(self, i, sampler, nburn):
+    def _burnin_convergence_test(self, i, sampler, nburn):
         if i < self.convergence_burnin_fraction*nburn:
             return False
         if np.mod(i+1, self.convergence_period) != 0:
             return False
-        c = self.get_convergence_statistic(i, sampler)
+        c = self._get_convergence_statistic(i, sampler)
         if np.all(c < self.convergence_threshold):
             self.convergence_number += 1
         else:
             self.convergence_number = 0
         return self.convergence_number > self.convergence_threshold_number

-    def prod_convergence_test(self, i, sampler, nburn):
+    def _prod_convergence_test(self, i, sampler, nburn):
         testA = i > nburn + self.convergence_length
         testB = np.mod(i+1, self.convergence_period) == 0
         if testA and testB:
-            self.get_convergence_statistic(i, sampler)
+            self._get_convergence_statistic(i, sampler)

-    def check_production_convergence(self, k):
+    def _check_production_convergence(self, k):
         bools = np.any(
             np.array(self.convergence_diagnostic)[k:, :]
             > self.convergence_prod_threshold, axis=1)
@@ -313,13 +314,13 @@ class MCMCSearch(core.BaseSearchClass):
                 '{} convergence tests in the production run of {} failed'
                 .format(np.sum(bools), len(bools)))

-    def run_sampler(self, sampler, p0, nprod=0, nburn=0):
+    def _run_sampler(self, sampler, p0, nprod=0, nburn=0):
         if hasattr(self, 'convergence_period'):
             logging.info('Running {} burn-in steps with convergence testing'
                          .format(nburn))
             iterator = tqdm(sampler.sample(p0, iterations=nburn), total=nburn)
             for i, output in enumerate(iterator):
-                if self.burnin_convergence_test(i, sampler, nburn):
+                if self._burnin_convergence_test(i, sampler, nburn):
                     logging.info(
                         'Converged at {} before max number {} of steps reached'
                         .format(i, nburn))
@@ -331,9 +332,9 @@ class MCMCSearch(core.BaseSearchClass):
             k = len(self.convergence_diagnostic)
             for result in tqdm(sampler.sample(output[0], iterations=nprod),
                                total=nprod):
-                self.prod_convergence_test(j, sampler, nburn)
+                self._prod_convergence_test(j, sampler, nburn)
                 j += 1
-            self.check_production_convergence(k)
+            self._check_production_convergence(k)
             return sampler
         else:
             for result in tqdm(sampler.sample(p0, iterations=nburn+nprod),
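The convergence helpers renamed above compare the variance within each walker's recent samples to the variance between walker means over a trailing window. A simplified, self-contained reimplementation of that idea (not the exact class method):

    # Gelman-Rubin style diagnostic: values near 1 suggest the walkers agree.
    import numpy as np

    def convergence_statistic(chain, window):
        # chain shape: (nwalkers, nsteps, ndim); use the last `window` steps
        s = chain[:, -window:, :]
        within = np.mean(np.var(s, axis=1), axis=0)   # mean within-walker variance
        between = np.var(np.mean(s, axis=1), axis=0)  # variance of walker means
        return np.sqrt(1.0 + between / within)        # per-parameter statistic

    np.random.seed(0)
    chain = np.random.randn(50, 200, 3)               # well-mixed fake chain
    print(convergence_statistic(chain, window=100))   # close to 1 per parameter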
@@ -342,50 +343,51 @@ class MCMCSearch(core.BaseSearchClass):
         return sampler

     def run(self, proposal_scale_factor=2, create_plots=True, **kwargs):
+        """ Run the MCMC simulation """

-        self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
+        self.old_data_is_okay_to_use = self._check_old_data_is_okay_to_use()
         if self.old_data_is_okay_to_use is True:
             logging.warning('Using saved data from {}'.format(
                 self.pickle_path))
-            d = self.get_saved_data()
+            d = self.get_saved_data_dictionary()
             self.sampler = d['sampler']
             self.samples = d['samples']
             self.lnprobs = d['lnprobs']
             self.lnlikes = d['lnlikes']
             return

-        self.initiate_search_object()
+        self._initiate_search_object()

         sampler = emcee.PTSampler(
             self.ntemps, self.nwalkers, self.ndim, self.logl, self.logp,
             logpargs=(self.theta_prior, self.theta_keys, self.search),
             loglargs=(self.search,), betas=self.betas,
             a=proposal_scale_factor)

-        p0 = self.generate_initial_p0()
-        p0 = self.apply_corrections_to_p0(p0)
-        self.check_initial_points(p0)
+        p0 = self._generate_initial_p0()
+        p0 = self._apply_corrections_to_p0(p0)
+        self._check_initial_points(p0)

         ninit_steps = len(self.nsteps) - 2
         for j, n in enumerate(self.nsteps[:-2]):
             logging.info('Running {}/{} initialisation with {} steps'.format(
                 j, ninit_steps, n))
-            sampler = self.run_sampler(sampler, p0, nburn=n)
+            sampler = self._run_sampler(sampler, p0, nburn=n)
             logging.info("Mean acceptance fraction: {}"
                          .format(np.mean(sampler.acceptance_fraction, axis=1)))
             if self.ntemps > 1:
                 logging.info("Tswap acceptance fraction: {}"
                              .format(sampler.tswap_acceptance_fraction))
             if create_plots:
-                fig, axes = self.plot_walkers(sampler,
+                fig, axes = self._plot_walkers(sampler,
                                               symbols=self.theta_symbols,
                                               **kwargs)
                 fig.tight_layout()
                 fig.savefig('{}/{}_init_{}_walkers.png'.format(
                     self.outdir, self.label, j), dpi=400)

-            p0 = self.get_new_p0(sampler)
-            p0 = self.apply_corrections_to_p0(p0)
-            self.check_initial_points(p0)
+            p0 = self._get_new_p0(sampler)
+            p0 = self._apply_corrections_to_p0(p0)
+            self._check_initial_points(p0)
             sampler.reset()

         if len(self.nsteps) > 1:
@@ -395,7 +397,7 @@ class MCMCSearch(core.BaseSearchClass):
             nprod = self.nsteps[-1]
         logging.info('Running final burn and prod with {} steps'.format(
             nburn+nprod))
-        sampler = self.run_sampler(sampler, p0, nburn=nburn, nprod=nprod)
+        sampler = self._run_sampler(sampler, p0, nburn=nburn, nprod=nprod)
         logging.info("Mean acceptance fraction: {}"
                      .format(np.mean(sampler.acceptance_fraction, axis=1)))
         if self.ntemps > 1:
@@ -403,7 +405,7 @@ class MCMCSearch(core.BaseSearchClass):
                          .format(sampler.tswap_acceptance_fraction))

         if create_plots:
-            fig, axes = self.plot_walkers(sampler, symbols=self.theta_symbols,
+            fig, axes = self._plot_walkers(sampler, symbols=self.theta_symbols,
                                           nprod=nprod, **kwargs)
             fig.tight_layout()
             fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label),
@@ -416,9 +418,9 @@ class MCMCSearch(core.BaseSearchClass):
         self.samples = samples
         self.lnprobs = lnprobs
         self.lnlikes = lnlikes
-        self.save_data(sampler, samples, lnprobs, lnlikes)
+        self._save_data(sampler, samples, lnprobs, lnlikes)

-    def get_rescale_multiplier_for_key(self, key):
+    def _get_rescale_multiplier_for_key(self, key):
         """ Get the rescale multiplier from the rescale_dictionary

         Can either be a float, a string (in which case it is interpretted as
@@ -443,7 +445,7 @@ class MCMCSearch(core.BaseSearchClass):
             multiplier = 1
         return multiplier

-    def get_rescale_subtractor_for_key(self, key):
+    def _get_rescale_subtractor_for_key(self, key):
         """ Get the rescale subtractor from the rescale_dictionary

         Can either be a float, a string (in which case it is interpretted as
@@ -468,21 +470,21 @@ class MCMCSearch(core.BaseSearchClass):
             subtractor = 0
         return subtractor

-    def scale_samples(self, samples, theta_keys):
+    def _scale_samples(self, samples, theta_keys):
         """ Scale the samples using the rescale_dictionary """
         for key in theta_keys:
             if key in self.rescale_dictionary:
                 idx = theta_keys.index(key)
                 s = samples[:, idx]
-                subtractor = self.get_rescale_subtractor_for_key(key)
+                subtractor = self._get_rescale_subtractor_for_key(key)
                 s = s - subtractor
-                multiplier = self.get_rescale_multiplier_for_key(key)
+                multiplier = self._get_rescale_multiplier_for_key(key)
                 s *= multiplier
                 samples[:, idx] = s
         return samples

-    def get_labels(self):
+    def _get_labels(self):
         """ Combine the units, symbols and rescaling to give labels """

         labels = []
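The rescale_dictionary convention these helpers read looks roughly like the following; the entry is modeled on the transient-search dictionary later in this patch, and the exact keys should be treated as inferred rather than documented here:

    # Per-key entries may carry a 'multiplier' and 'subtractor' (floats, or
    # strings naming an instance attribute to look up), plus display metadata.
    rescale_dictionary = {
        'transient_tstart': {
            'multiplier': 1 / 86400.,      # seconds -> days
            'subtractor': 'minStartTime',  # string -> resolved as an attribute
            'unit': 'day',
            'label': 'Transient start-time \n days after minStartTime'},
    }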
@@ -503,9 +505,38 @@ class MCMCSearch(core.BaseSearchClass):
             labels.append(label)
         return labels

-    def plot_corner(self, figsize=(7, 7), tglitch_ratio=False,
-                    add_prior=False, nstds=None, label_offset=0.4,
-                    dpi=300, rc_context={}, **kwargs):
+    def plot_corner(self, figsize=(7, 7), add_prior=False, nstds=None,
+                    label_offset=0.4, dpi=300, rc_context={},
+                    tglitch_ratio=False, **kwargs):
+        """ Generate a corner plot of the posterior
+
+        Using the `corner` package (https://pypi.python.org/pypi/corner/),
+        generate estimates of the posterior from the production samples.
+
+        Parameters
+        ----------
+        figsize: tuple (7, 7)
+            Figure size in inches (passed to plt.subplots)
+        add_prior: bool
+            If true, plot the prior as a red line
+        nstds: float
+            The number of standard deviations to plot centered on the mean
+        label_offset: float
+            Offset the labels from the plot: useful to prevent overlapping the
+            tick labels with the axis labels
+        dpi: int
+            Passed to plt.savefig
+        rc_context: dict
+            Dictionary of rc values to set while generating the figure (see
+            matplotlib rc for more details)
+        tglitch_ratio: bool
+            If true, and tglitch is a parameter, plot posteriors as the
+            fractional time at which the glitch occurs instead of the actual
+            time
+
+        Note: kwargs are passed on to corner.corner
+
+        """

         if self.ndim < 2:
             with plt.rc_context(rc_context):
@@ -522,9 +553,9 @@ class MCMCSearch(core.BaseSearchClass):
                                      figsize=figsize)

             samples_plt = copy.copy(self.samples)
-            labels = self.get_labels()
+            labels = self._get_labels()

-            samples_plt = self.scale_samples(samples_plt, self.theta_keys)
+            samples_plt = self._scale_samples(samples_plt, self.theta_keys)

             if tglitch_ratio:
                 for j, k in enumerate(self.theta_keys):
@@ -571,20 +602,20 @@ class MCMCSearch(core.BaseSearchClass):
             fig.subplots_adjust(hspace=0.05, wspace=0.05)

             if add_prior:
-                self.add_prior_to_corner(axes, self.samples)
+                self._add_prior_to_corner(axes, self.samples)

             fig_triangle.savefig('{}/{}_corner.png'.format(
                 self.outdir, self.label), dpi=dpi)

-    def add_prior_to_corner(self, axes, samples):
+    def _add_prior_to_corner(self, axes, samples):
         for i, key in enumerate(self.theta_keys):
             ax = axes[i][i]
             xlim = ax.get_xlim()
             s = samples[:, i]
-            prior = self.generic_lnprior(**self.theta_prior[key])
+            prior = self._generic_lnprior(**self.theta_prior[key])
             x = np.linspace(s.min(), s.max(), 100)
-            multiplier = self.get_rescale_multiplier_for_key(key)
-            subtractor = self.get_rescale_subtractor_for_key(key)
+            multiplier = self._get_rescale_multiplier_for_key(key)
+            subtractor = self._get_rescale_subtractor_for_key(key)
             ax2 = ax.twinx()
             ax2.get_yaxis().set_visible(False)
             ax2.plot((x-subtractor)*multiplier, [prior(xi) for xi in x], '-r')
@@ -598,7 +629,7 @@ class MCMCSearch(core.BaseSearchClass):

         for i, (ax, key) in enumerate(zip(axes, self.theta_keys)):
             prior_dict = self.theta_prior[key]
-            prior_func = self.generic_lnprior(**prior_dict)
+            prior_func = self._generic_lnprior(**prior_dict)
             if prior_dict['type'] == 'unif':
                 x = np.linspace(prior_dict['lower'], prior_dict['upper'], N)
                 prior = prior_func(x)
@@ -643,13 +674,17 @@ class MCMCSearch(core.BaseSearchClass):
             self.outdir, self.label))

     def plot_cumulative_max(self, **kwargs):
+        """ Plot the cumulative twoF for the maximum posterior estimate
+
+        See the pyfstat.core.plot_twoF_cumulative function for further details
+        """
         d, maxtwoF = self.get_max_twoF()
         for key, val in self.theta_prior.iteritems():
             if key not in d:
                 d[key] = val

         if hasattr(self, 'search') is False:
-            self.initiate_search_object()
+            self._initiate_search_object()
         if self.binary is False:
             self.search.plot_twoF_cumulative(
                 self.label, self.outdir, F0=d['F0'], F1=d['F1'], F2=d['F2'],
@@ -663,7 +698,7 @@ class MCMCSearch(core.BaseSearchClass):
                 period=d['period'], ecc=d['ecc'], argp=d['argp'], tp=d['argp'],
                 tstart=self.minStartTime, tend=self.maxStartTime, **kwargs)

-    def generic_lnprior(self, **kwargs):
+    def _generic_lnprior(self, **kwargs):
         """ Return a lambda function of the pdf

         Parameters
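An example call against the documented plot_corner signature, assuming a completed `search.run()` as in the earlier sketch:

    # Overlay the prior and restrict each panel to 3 standard deviations.
    search.plot_corner(figsize=(8, 8), add_prior=True, nstds=3,
                       label_offset=0.35, dpi=200,
                       rc_context={'font.size': 8})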
@@ -715,7 +750,7 @@ class MCMCSearch(core.BaseSearchClass):
             logging.info("kwargs:", kwargs)
             raise ValueError("Print unrecognise distribution")

-    def generate_rv(self, **kwargs):
+    def _generate_rv(self, **kwargs):
         dist_type = kwargs.pop('type')
         if dist_type == "unif":
             return np.random.uniform(low=kwargs['lower'], high=kwargs['upper'])
@@ -733,10 +768,10 @@ class MCMCSearch(core.BaseSearchClass):
         else:
             raise ValueError("dist_type {} unknown".format(dist_type))

-    def plot_walkers(self, sampler, symbols=None, alpha=0.4, color="k", temp=0,
-                     lw=0.1, nprod=0, add_det_stat_burnin=False,
-                     fig=None, axes=None, xoffset=0, plot_det_stat=False,
-                     context='classic', subtractions=None, labelpad=0.05):
+    def _plot_walkers(self, sampler, symbols=None, alpha=0.4, color="k",
+                      temp=0, lw=0.1, nprod=0, add_det_stat_burnin=False,
+                      fig=None, axes=None, xoffset=0, plot_det_stat=False,
+                      context='classic', subtractions=None, labelpad=0.05):
         """ Plot all the chains from a sampler """

         if np.ndim(axes) > 1:
@@ -865,48 +900,48 @@ class MCMCSearch(core.BaseSearchClass):
         axes[-2].set_xlabel(r'$\textrm{Number of steps}$', labelpad=0.2)
         return fig, axes

-    def apply_corrections_to_p0(self, p0):
+    def _apply_corrections_to_p0(self, p0):
         """ Apply any correction to the initial p0 values """
         return p0

-    def generate_scattered_p0(self, p):
+    def _generate_scattered_p0(self, p):
         """ Generate a set of p0s scattered about p """
         p0 = [[p + self.scatter_val * p * np.random.randn(self.ndim)
               for i in xrange(self.nwalkers)]
              for j in xrange(self.ntemps)]
         return p0

-    def generate_initial_p0(self):
+    def _generate_initial_p0(self):
         """ Generate a set of init vals for the walkers """

         if type(self.theta_initial) == dict:
             logging.info('Generate initial values from initial dictionary')
             if hasattr(self, 'nglitch') and self.nglitch > 1:
                 raise ValueError('Initial dict not implemented for nglitch>1')
-            p0 = [[[self.generate_rv(**self.theta_initial[key])
+            p0 = [[[self._generate_rv(**self.theta_initial[key])
                     for key in self.theta_keys]
                    for i in range(self.nwalkers)]
                   for j in range(self.ntemps)]
         elif type(self.theta_initial) == list:
             logging.info('Generate initial values from list of theta_initial')
-            p0 = [[[self.generate_rv(**val)
+            p0 = [[[self._generate_rv(**val)
                     for val in self.theta_initial]
                    for i in range(self.nwalkers)]
                   for j in range(self.ntemps)]
         elif self.theta_initial is None:
             logging.info('Generate initial values from prior dictionary')
-            p0 = [[[self.generate_rv(**self.theta_prior[key])
+            p0 = [[[self._generate_rv(**self.theta_prior[key])
                     for key in self.theta_keys]
                    for i in range(self.nwalkers)]
                   for j in range(self.ntemps)]
         elif len(self.theta_initial) == self.ndim:
-            p0 = self.generate_scattered_p0(self.theta_initial)
+            p0 = self._generate_scattered_p0(self.theta_initial)
         else:
             raise ValueError('theta_initial not understood')

         return p0

-    def get_new_p0(self, sampler):
+    def _get_new_p0(self, sampler):
         """ Returns new initial positions for walkers are burn0 stage

         This returns new positions for all walkers by scattering points about
@@ -936,7 +971,7 @@ class MCMCSearch(core.BaseSearchClass):
         lnp_finite[np.isinf(lnp)] = np.nan
         idx = np.unravel_index(np.nanargmax(lnp_finite), lnp_finite.shape)
         p = pF[idx]
-        p0 = self.generate_scattered_p0(p)
+        p0 = self._generate_scattered_p0(p)

         self.search.BSGL = False
         twoF = self.logl(p, self.search)
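The initialisation above produces nested lists of shape (ntemps, nwalkers, ndim). A shape sketch with plain numpy draws standing in for `_generate_rv`:

    # One draw per (temperature, walker, parameter) from a toy uniform prior.
    import numpy as np

    ntemps, nwalkers = 2, 100
    theta_prior = {'F0': (29.9, 30.1), 'F1': (-1e-9, 0)}  # (lower, upper)
    theta_keys = ['F0', 'F1']

    p0 = [[[np.random.uniform(*theta_prior[key]) for key in theta_keys]
           for i in range(nwalkers)]
          for j in range(ntemps)]
    print(np.shape(p0))  # (2, 100, 2) == (ntemps, nwalkers, ndim)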
@@ -948,7 +983,7 @@ class MCMCSearch(core.BaseSearchClass):

         return p0

-    def get_save_data_dictionary(self):
+    def _get_data_dictionary_to_save(self):
         d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers,
                  ntemps=self.ntemps, theta_keys=self.theta_keys,
                  theta_prior=self.theta_prior, scatter_val=self.scatter_val,
@@ -956,8 +991,8 @@ class MCMCSearch(core.BaseSearchClass):
                  BSGL=self.BSGL)
         return d

-    def save_data(self, sampler, samples, lnprobs, lnlikes):
-        d = self.get_save_data_dictionary()
+    def _save_data(self, sampler, samples, lnprobs, lnlikes):
+        d = self._get_data_dictionary_to_save()
         d['sampler'] = sampler
         d['samples'] = samples
         d['lnprobs'] = lnprobs
@@ -970,12 +1005,13 @@ class MCMCSearch(core.BaseSearchClass):
         with open(self.pickle_path, "wb") as File:
             pickle.dump(d, File)

-    def get_saved_data(self):
+    def get_saved_data_dictionary(self):
+        """ Returns the dictionary of data saved in the pickle """
         with open(self.pickle_path, "r") as File:
             d = pickle.load(File)
         return d

-    def check_old_data_is_okay_to_use(self):
+    def _check_old_data_is_okay_to_use(self):
         if args.use_old_data:
             logging.info("Forcing use of old data")
             return True
@@ -986,13 +1022,13 @@ class MCMCSearch(core.BaseSearchClass):

         if self.sftfilepath is not None:
             oldest_sft = min([os.path.getmtime(f) for f in
-                              self.get_list_of_matching_sfts()])
+                              self._get_list_of_matching_sfts()])
             if os.path.getmtime(self.pickle_path) < oldest_sft:
                 logging.info('Pickled data outdates sft files')
                 return False

-        old_d = self.get_saved_data().copy()
-        new_d = self.get_save_data_dictionary().copy()
+        old_d = self.get_saved_data_dictionary().copy()
+        new_d = self._get_data_dictionary_to_save().copy()

         old_d.pop('samples')
         old_d.pop('sampler')
@@ -1043,7 +1079,7 @@ class MCMCSearch(core.BaseSearchClass):

         if self.BSGL:
             if hasattr(self, 'search') is False:
-                self.initiate_search_object()
+                self._initiate_search_object()
             p = self.samples[jmax]
             self.search.BSGL = False
             maxtwoF = self.logl(p, self.search)
@@ -1089,6 +1125,13 @@ class MCMCSearch(core.BaseSearchClass):
         return d

     def check_if_samples_are_railing(self, threshold=0.01):
+        """ Returns a boolean estimate of whether the samples are railing
+
+        Parameters
+        ----------
+        threshold: float [0, 1]
+            Fraction of the uniform prior to test (at upper and lower bound)
+        """
         return_flag = False
         for s, k in zip(self.samples.T, self.theta_keys):
             prior = self.theta_prior[k]
@@ -1160,6 +1203,7 @@ class MCMCSearch(core.BaseSearchClass):
             f.write("\n\end{tabular}\n")

     def print_summary(self):
+        """ Prints a summary of the max twoF found to the terminal """
         max_twoFd, max_twoF = self.get_max_twoF()
         median_std_d = self.get_median_stds()
         logging.info('Summary:')
@@ -1175,18 +1219,18 @@ class MCMCSearch(core.BaseSearchClass):
                 k, median_std_d[k], median_std_d[k+'_std']))
         logging.info('\n')

-    def CF_twoFmax(self, theta, twoFmax, ntrials):
+    def _CF_twoFmax(self, theta, twoFmax, ntrials):
         Fmax = twoFmax/2.0
         return (np.exp(1j*theta*twoFmax)*ntrials/2.0
                 * Fmax*np.exp(-Fmax)*(1-(1+Fmax)*np.exp(-Fmax))**(ntrials-1))

-    def pdf_twoFhat(self, twoFhat, nglitch, ntrials, twoFmax=100, dtwoF=0.1):
+    def _pdf_twoFhat(self, twoFhat, nglitch, ntrials, twoFmax=100, dtwoF=0.1):
         if np.ndim(ntrials) == 0:
             ntrials = np.zeros(nglitch+1) + ntrials
         twoFmax_int = np.arange(0, twoFmax, dtwoF)
         theta_int = np.arange(-1/dtwoF, 1./dtwoF, 1./twoFmax)
         CF_twoFmax_theta = np.array(
-            [[np.trapz(self.CF_twoFmax(t, twoFmax_int, ntrial), twoFmax_int)
+            [[np.trapz(self._CF_twoFmax(t, twoFmax_int, ntrial), twoFmax_int)
               for t in theta_int]
             for ntrial in ntrials])
         CF_twoFhat_theta = np.prod(CF_twoFmax_theta, axis=0)
@@ -1195,7 +1239,7 @@ class MCMCSearch(core.BaseSearchClass):
             * CF_twoFhat_theta, theta_int) for twoFhat_val in twoFhat])
         return pdf.real
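The `_CF_twoFmax`/`_pdf_twoFhat` machinery builds the distribution of the loudest 2F over many trials per glitch segment via characteristic functions. For a single segment the same quantity has a closed form: under Gaussian noise 2F follows a chi-squared distribution with 4 degrees of freedom, whose CDF is the (1-(1+Fmax)*exp(-Fmax)) factor visible above. A cross-checking sketch using scipy:

    # p-value for the maximum of N independent 2F trials in one segment:
    #     p(2Fhat) = 1 - CDF_chi2_4(2Fhat)**N
    from scipy.stats import chi2

    def p_val_single_segment(twoFhat, ntrials):
        return 1.0 - chi2.cdf(twoFhat, df=4) ** ntrials

    print(p_val_single_segment(60.0, ntrials=1e6))  # small -> unlikely in noise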
-    def p_val_twoFhat(self, twoFhat, ntrials, twoFhatmax=500, Npoints=1000):
+    def _p_val_twoFhat(self, twoFhat, ntrials, twoFhatmax=500, Npoints=1000):
         """ Caluculate the p-value for the given twoFhat in Gaussian noise

         Parameters
         ----------
@@ -1206,7 +1250,7 @@ class MCMCSearch(core.BaseSearchClass):
             The number of trials for each glitch+1
         """
         twoFhats = np.linspace(twoFhat, twoFhatmax, Npoints)
-        pdf = self.pdf_twoFhat(twoFhats, self.nglitch, ntrials)
+        pdf = self._pdf_twoFhat(twoFhats, self.nglitch, ntrials)
         return np.trapz(pdf, twoFhats)

     def get_p_value(self, delta_F0, time_trials=0):
@@ -1220,11 +1264,12 @@ class MCMCSearch(core.BaseSearchClass):
         tboundaries = [self.minStartTime] + tglitches + [self.maxStartTime]
         deltaTs = np.diff(tboundaries)
         ntrials = [time_trials + delta_F0 * dT for dT in deltaTs]
-        p_val = self.p_val_twoFhat(max_twoF, ntrials)
+        p_val = self._p_val_twoFhat(max_twoF, ntrials)
         print('p-value = {}'.format(p_val))
         return p_val

     def get_evidence(self):
+        """ Get the log10 evidence and error estimate """
         fburnin = float(self.nsteps[-2])/np.sum(self.nsteps[-2:])
         lnev, lnev_err = self.sampler.thermodynamic_integration_log_evidence(
             fburnin=fburnin)
@@ -1233,7 +1278,7 @@ class MCMCSearch(core.BaseSearchClass):
         log10evidence_err = lnev_err/np.log(10)
         return log10evidence, log10evidence_err

-    def compute_evidence_long(self):
+    def _compute_evidence_long(self):
         """ Computes the evidence/marginal likelihood for the model """
         betas = self.betas
         alllnlikes = self.sampler.lnlikelihood[:, :, self.nsteps[-2]:]
@@ -1358,12 +1403,12 @@ class MCMCGlitchSearch(MCMCSearch):

         if os.path.isdir(outdir) is False:
             os.mkdir(outdir)
-        self.add_log_file()
+        self._add_log_file()
         logging.info(('Set-up MCMC glitch search with {} glitches for model {}'
                       ' on data {}').format(self.nglitch, self.label,
                                             self.sftfilepath))
         self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
-        self.unpack_input_theta()
+        self._unpack_input_theta()
         self.ndim = len(self.theta_keys)
         if self.log10temperature_min:
             self.betas = np.logspace(0, self.log10temperature_min, self.ntemps)
@@ -1377,10 +1422,10 @@ class MCMCGlitchSearch(MCMCSearch):
         if args.clean and os.path.isfile(self.pickle_path):
             os.rename(self.pickle_path, self.pickle_path+".old")

-        self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
-        self.log_input()
+        self.old_data_is_okay_to_use = self._check_old_data_is_okay_to_use()
+        self._log_input()

-    def initiate_search_object(self):
+    def _initiate_search_object(self):
         logging.info('Setting up search object')
         self.search = core.SemiCoherentGlitchSearch(
             label=self.label, outdir=self.outdir, sftfilepath=self.sftfilepath,
@@ -1399,7 +1444,7 @@ class MCMCGlitchSearch(MCMCSearch):
         if any(np.diff(ts) < self.dtglitchmin):
             return -np.inf

-        H = [self.generic_lnprior(**theta_prior[key])(p) for p, key in
+        H = [self._generic_lnprior(**theta_prior[key])(p) for p, key in
             zip(theta_vals, theta_keys)]
         return np.sum(H)
@@ -1415,7 +1460,7 @@ class MCMCGlitchSearch(MCMCSearch):
             FS = search.compute_nglitch_fstat(*self.fixed_theta)
         return FS

-    def unpack_input_theta(self):
+    def _unpack_input_theta(self):
         glitch_keys = ['delta_F0', 'delta_F1', 'tglitch']
         full_glitch_keys = list(np.array(
             [[gk]*self.nglitch for gk in glitch_keys]).flatten())
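The new get_evidence docstring amounts to this conversion of emcee's PTSampler thermodynamic-integration output from natural log to base 10 (stand-in values for illustration):

    # lnev and lnev_err come from sampler.thermodynamic_integration_log_evidence.
    import numpy as np
    lnev, lnev_err = -42.0, 0.5  # stand-in values
    log10evidence = lnev / np.log(10)
    log10evidence_err = lnev_err / np.log(10)
    print(log10evidence, log10evidence_err)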
@@ -1478,7 +1523,7 @@ class MCMCGlitchSearch(MCMCSearch):
             if idx in self.theta_idxs[:i]:
                 self.theta_idxs[i] += 1

-    def get_save_data_dictionary(self):
+    def _get_data_dictionary_to_save(self):
         d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers,
                  ntemps=self.ntemps, theta_keys=self.theta_keys,
                  theta_prior=self.theta_prior, scatter_val=self.scatter_val,
@@ -1486,7 +1531,7 @@ class MCMCGlitchSearch(MCMCSearch):
                  theta0_idx=self.theta0_idx, BSGL=self.BSGL)
         return d

-    def apply_corrections_to_p0(self, p0):
+    def _apply_corrections_to_p0(self, p0):
         p0 = np.array(p0)
         if self.nglitch > 1:
             p0[:, :, -self.nglitch:] = np.sort(p0[:, :, -self.nglitch:],
@@ -1559,12 +1604,12 @@ class MCMCSemiCoherentSearch(MCMCSearch):

         if os.path.isdir(outdir) is False:
             os.mkdir(outdir)
-        self.add_log_file()
+        self._add_log_file()
         logging.info(('Set-up MCMC semi-coherent search for model {} on data'
                       '{}').format(
             self.label, self.sftfilepath))
         self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
-        self.unpack_input_theta()
+        self._unpack_input_theta()
         self.ndim = len(self.theta_keys)
         if self.log10temperature_min:
             self.betas = np.logspace(0, self.log10temperature_min, self.ntemps)
@@ -1578,9 +1623,9 @@ class MCMCSemiCoherentSearch(MCMCSearch):
         if args.clean and os.path.isfile(self.pickle_path):
             os.rename(self.pickle_path, self.pickle_path+".old")

-        self.log_input()
+        self._log_input()

-    def get_save_data_dictionary(self):
+    def _get_data_dictionary_to_save(self):
         d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers,
                  ntemps=self.ntemps, theta_keys=self.theta_keys,
                  theta_prior=self.theta_prior, scatter_val=self.scatter_val,
@@ -1588,7 +1633,7 @@ class MCMCSemiCoherentSearch(MCMCSearch):
                  BSGL=self.BSGL, nsegs=self.nsegs)
         return d

-    def initiate_search_object(self):
+    def _initiate_search_object(self):
         logging.info('Setting up search object')
         self.search = core.SemiCoherentSearch(
             label=self.label, outdir=self.outdir, tref=self.tref,
@@ -1600,7 +1645,7 @@ class MCMCSemiCoherentSearch(MCMCSearch):
             injectSources=self.injectSources, assumeSqrtSX=self.assumeSqrtSX)

     def logp(self, theta_vals, theta_prior, theta_keys, search):
-        H = [self.generic_lnprior(**theta_prior[key])(p) for p, key in
+        H = [self._generic_lnprior(**theta_prior[key])(p) for p, key in
             zip(theta_vals, theta_keys)]
         return np.sum(H)
@@ -1614,7 +1659,7 @@ class MCMCSemiCoherentSearch(MCMCSearch):

 class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
     """ A follow up procudure increasing the coherence time in a zoom """

-    def get_save_data_dictionary(self):
+    def _get_data_dictionary_to_save(self):
         d = dict(nwalkers=self.nwalkers, ntemps=self.ntemps,
                  theta_keys=self.theta_keys, theta_prior=self.theta_prior,
                  scatter_val=self.scatter_val,
@@ -1840,17 +1885,17 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
         """

         self.nsegs = 1
-        self.initiate_search_object()
+        self._initiate_search_object()
         run_setup = self.init_run_setup(
             run_setup, R=R, Nsegs0=Nsegs0, log_table=log_table,
             gen_tex_table=gen_tex_table)
         self.run_setup = run_setup

-        self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
+        self.old_data_is_okay_to_use = self._check_old_data_is_okay_to_use()
         if self.old_data_is_okay_to_use is True:
             logging.warning('Using saved data from {}'.format(
                 self.pickle_path))
-            d = self.get_saved_data()
+            d = self.get_saved_data_dictionary()
             self.sampler = d['sampler']
             self.samples = d['samples']
             self.lnprobs = d['lnprobs']
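The follow-up schedule consumed below is a list of ((nburn, nprod), nseg, reset_p0) stages, per the loop unpacking; a hypothetical zoom schedule (the `run_setup` keyword name is taken from the method body above, the numbers are purely illustrative):

    run_setup = [
        ((100, 0), 100, False),   # 100 burn-in steps over 100 segments
        ((100, 0), 25, True),     # re-seed walkers, zoom to 25 segments
        ((100, 100), 1, False),   # fully-coherent burn-in plus production
    ]
    search.run(run_setup=run_setup)  # `search` an MCMCFollowUpSearch instance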
@@ -1861,12 +1906,12 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
         nsteps_total = 0
         for j, ((nburn, nprod), nseg, reset_p0) in enumerate(run_setup):
             if j == 0:
-                p0 = self.generate_initial_p0()
-                p0 = self.apply_corrections_to_p0(p0)
+                p0 = self._generate_initial_p0()
+                p0 = self._apply_corrections_to_p0(p0)
             elif reset_p0:
-                p0 = self.get_new_p0(sampler)
-                p0 = self.apply_corrections_to_p0(p0)
-                # self.check_initial_points(p0)
+                p0 = self._get_new_p0(sampler)
+                p0 = self._apply_corrections_to_p0(p0)
+                # self._check_initial_points(p0)
             else:
                 p0 = sampler.chain[:, :, -1, :]

@@ -1884,7 +1929,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
             logging.info(('Running {}/{} with {} steps and {} nsegs '
                           '(Tcoh={:1.2f} days)').format(
                 j+1, len(run_setup), (nburn, nprod), nseg, Tcoh))
-            sampler = self.run_sampler(sampler, p0, nburn=nburn, nprod=nprod)
+            sampler = self._run_sampler(sampler, p0, nburn=nburn, nprod=nprod)
             logging.info("Mean acceptance fraction: {}"
                          .format(np.mean(sampler.acceptance_fraction, axis=1)))
             if self.ntemps > 1:
@@ -1894,7 +1939,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
                 np.max(sampler.lnlikelihood)))

             if create_plots:
-                fig, axes = self.plot_walkers(
+                fig, axes = self._plot_walkers(
                     sampler, symbols=self.theta_symbols, fig=fig, axes=axes,
                     nprod=nprod, xoffset=nsteps_total, **kwargs)
                 for ax in axes[:self.ndim]:
@@ -1909,7 +1954,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
         self.samples = samples
         self.lnprobs = lnprobs
         self.lnlikes = lnlikes
-        self.save_data(sampler, samples, lnprobs, lnlikes)
+        self._save_data(sampler, samples, lnprobs, lnlikes)

         if create_plots:
             try:
@@ -1945,7 +1990,7 @@ class MCMCTransientSearch(MCMCSearch):
              'label': 'Transient start-time \n days after minStartTime'}
            )

-    def initiate_search_object(self):
+    def _initiate_search_object(self):
         logging.info('Setting up search object')
         self.search = core.ComputeFstat(
             tref=self.tref, sftfilepath=self.sftfilepath,
@@ -1965,7 +2010,7 @@ class MCMCTransientSearch(MCMCSearch):
             FS = search.run_computefstatistic_single_point(*in_theta)
         return FS

-    def unpack_input_theta(self):
+    def _unpack_input_theta(self):
         full_theta_keys = ['transient_tstart', 'transient_duration',
                            'F0', 'F1', 'F2', 'Alpha', 'Delta']
diff --git a/tests.py b/tests.py
index 69e332025a69edf266e7f71f8ae4f496e84a59cc..503eaf1b53c1b2704accefd7e8e384855c162733 100644
--- a/tests.py
+++ b/tests.py
@@ -49,7 +49,7 @@ class TestBaseSearchClass(Test):
     def test_shift_matrix(self):
         BSC = pyfstat.BaseSearchClass()
         dT = 10
-        a = BSC.shift_matrix(4, dT)
+        a = BSC._shift_matrix(4, dT)
         b = np.array([[1, 2*np.pi*dT, 2*np.pi*dT**2/2.0, 2*np.pi*dT**3/6.0],
                       [0, 1, dT, dT**2/2.0],
                       [0, 0, 1, dT],
@@ -71,16 +71,16 @@ class TestBaseSearchClass(Test):

         self.assertTrue(
             np.array_equal(
-                thetaB, BSC.shift_coefficients(thetaA, dT)))
+                thetaB, BSC._shift_coefficients(thetaA, dT)))

     def test_shift_coefficients_loop(self):
         BSC = pyfstat.BaseSearchClass()
         thetaA = np.array([10., 1e2, 10., 1e2])
         dT = 1e1
-        thetaB = BSC.shift_coefficients(thetaA, dT)
+        thetaB = BSC._shift_coefficients(thetaA, dT)
         self.assertTrue(
             np.allclose(
-                thetaA, BSC.shift_coefficients(thetaB, -dT),
+                thetaA, BSC._shift_coefficients(thetaB, -dT),
                 rtol=1e-9, atol=1e-9))

@@ -257,8 +257,7 @@ class TestAuxillaryFunctions(Test):
     DeltaFs = [1e-4, 1e-14]
     fiducial_freq = 100
     detector_names = ['H1', 'L1']
-    earth_ephem = pyfstat.earth_ephem
-    sun_ephem = pyfstat.sun_ephem
+    earth_ephem, sun_ephem = pyfstat.helper_functions.set_up_ephemeris_configuration()

     def test_get_V_estimate_sky_F0_F1(self):