diff --git a/pyfstat/core.py b/pyfstat/core.py index 7a23a08a158cc588fc8f969809f746922f15a80c..46e6efd93ea789843b0ef394edfc593b37ea18bb 100755 --- a/pyfstat/core.py +++ b/pyfstat/core.py @@ -13,9 +13,8 @@ import lal import lalpulsar import helper_functions -tqdm = helper_functions.set_up_optional_tqdm() helper_functions.set_up_matplotlib_defaults() -args = helper_functions.set_up_command_line_arguments() +args, tqdm = helper_functions.set_up_command_line_arguments() earth_ephem, sun_ephem = helper_functions.set_up_ephemeris_configuration() @@ -672,7 +671,7 @@ class Writer(BaseSearchClass): delta_F2=0, tref=None, F0=30, F1=1e-10, F2=0, Alpha=5e-3, Delta=6e-2, h0=0.1, cosi=0.0, psi=0.0, phi=0, Tsft=1800, outdir=".", sqrtSX=1, Band=4, detector='H1', - minStartTime=None, maxStartTime=None): + minStartTime=None, maxStartTime=None, add_noise=False): """ Parameters ---------- @@ -701,19 +700,25 @@ class Writer(BaseSearchClass): for d in self.delta_phi, self.delta_F0, self.delta_F1, self.delta_F2: if np.size(d) == 1: - d = [d] + d = np.atleast_1d(d) self.tend = self.tstart + self.duration if self.minStartTime is None: self.minStartTime = self.tstart if self.maxStartTime is None: self.maxStartTime = self.tend - if self.dtglitch is None or self.dtglitch == self.duration: + if self.dtglitch is None or all(self.dtglitch == self.duration): self.tbounds = [self.tstart, self.tend] elif np.size(self.dtglitch) == 1: - self.tbounds = [self.tstart, self.tstart+self.dtglitch, self.tend] + self.dtglitch = np.array(dtglitch) + self.tbounds = np.concatenate(( + [self.tstart], self.tstart+self.dtglitch, [self.tend])) else: - self.tglitch = self.tstart + np.array(self.dtglitch) - self.tbounds = [self.tstart] + list(self.tglitch) + [self.tend] + self.dtglitch = np.array(dtglitch) + self.tglitch = self.tstart + self.dtglitch + self.tbounds = np.concatenate(( + [self.tstart], self.tglitch, [self.tend])) + + self.check_inputs() if os.path.isdir(self.outdir) is False: os.makedirs(self.outdir) @@ -736,6 +741,16 @@ class Writer(BaseSearchClass): self.sftfilepath = '{}/{}'.format(self.outdir, self.sftfilename) self.calculate_fmin_Band() + def check_inputs(self): + self.minStartTime = int(self.minStartTime) + self.maxStartTime = int(self.maxStartTime) + shapes = np.array([np.shape(x) for x in [self.delta_phi, self.delta_F0, + self.delta_F1, self.delta_F2]] + ) + if not np.all(shapes == shapes[0]): + raise ValueError('all delta_* must be the same shape: {}'.format( + shapes)) + def make_data(self): ''' A convienience wrapper to generate a cff file then sfts ''' self.make_cff() @@ -865,7 +880,8 @@ transientTauDays={:1.3f}\n""") cl.append('--outSFTdir="{}"'.format(self.outdir)) cl.append('--outLabel="{}"'.format(self.label)) cl.append('--IFOs="{}"'.format(self.detector)) - cl.append('--sqrtSX="{}"'.format(self.sqrtSX)) + if self.add_noise: + cl.append('--sqrtSX="{}"'.format(self.sqrtSX)) if self.minStartTime is None: cl.append('--startTime={:10.9f}'.format(float(self.tstart))) else: diff --git a/pyfstat/helper_functions.py b/pyfstat/helper_functions.py index 3034df7fad0d78956797dd14461658176bee23e4..0a809486e3cbf903c7fff71ae9a1df58fb0d7bc4 100644 --- a/pyfstat/helper_functions.py +++ b/pyfstat/helper_functions.py @@ -45,6 +45,8 @@ def set_up_command_line_arguments(): if args.quite or args.no_interactive: def tqdm(x, *args, **kwargs): return x + else: + tqdm = set_up_optional_tqdm() logger = logging.getLogger() logger.setLevel(logging.DEBUG) stream_handler = logging.StreamHandler() @@ -55,7 +57,7 @@ def set_up_command_line_arguments(): stream_handler.setFormatter(logging.Formatter( '%(asctime)s %(levelname)-8s: %(message)s', datefmt='%H:%M')) logger.addHandler(stream_handler) - return args + return args, tqdm def set_up_ephemeris_configuration():