Skip to content
Snippets Groups Projects
Commit a2acfcff authored by Reinhard Prix's avatar Reinhard Prix
Browse files

apply 'black' coding style (stricter subset of PEP8)

- supply consistent flake8 settings in setup.cfg
parent 96e1046c
No related branches found
No related tags found
No related merge requests found
...@@ -16,64 +16,97 @@ import lal ...@@ -16,64 +16,97 @@ import lal
import lalpulsar import lalpulsar
# workaround for matplotlib on X-less remote logins # workaround for matplotlib on X-less remote logins
if 'DISPLAY' in os.environ: if "DISPLAY" in os.environ:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
else: else:
logging.info('No $DISPLAY environment variable found, so importing \ logging.info(
matplotlib.pyplot with non-interactive "Agg" backend.') 'No $DISPLAY environment variable found, so importing \
matplotlib.pyplot with non-interactive "Agg" backend.'
)
import matplotlib import matplotlib
matplotlib.use('Agg')
matplotlib.use("Agg")
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
def set_up_optional_tqdm(): def set_up_optional_tqdm():
try: try:
from tqdm import tqdm from tqdm import tqdm
except ImportError: except ImportError:
def tqdm(x, *args, **kwargs): def tqdm(x, *args, **kwargs):
return x return x
return tqdm return tqdm
def set_up_matplotlib_defaults(): def set_up_matplotlib_defaults():
plt.switch_backend('Agg') plt.switch_backend("Agg")
plt.rcParams['text.usetex'] = True plt.rcParams["text.usetex"] = True
plt.rcParams['axes.formatter.useoffset'] = False plt.rcParams["axes.formatter.useoffset"] = False
def set_up_command_line_arguments(): def set_up_command_line_arguments():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-v", "--verbose", action="store_true", parser.add_argument(
help="Increase output verbosity [logging.DEBUG]") "-v",
parser.add_argument("-q", "--quite", action="store_true", "--verbose",
help="Decrease output verbosity [logging.WARNING]") action="store_true",
parser.add_argument("--no-interactive", help="Don't use interactive", help="Increase output verbosity [logging.DEBUG]",
action="store_true") )
parser.add_argument("-c", "--clean", action="store_true", parser.add_argument(
help="Force clean data, never use cached data") "-q",
"--quite",
action="store_true",
help="Decrease output verbosity [logging.WARNING]",
)
parser.add_argument(
"--no-interactive", help="Don't use interactive", action="store_true"
)
parser.add_argument(
"-c",
"--clean",
action="store_true",
help="Force clean data, never use cached data",
)
fu_parser = parser.add_argument_group( fu_parser = parser.add_argument_group(
'follow-up options', 'Options related to MCMCFollowUpSearch') "follow-up options", "Options related to MCMCFollowUpSearch"
fu_parser.add_argument('-s', "--setup-only", action="store_true", )
help="Only generate the setup file, don't run")
fu_parser.add_argument( fu_parser.add_argument(
"--no-template-counting", action="store_true", "-s",
help="No counting of templates, useful if the setup is predefined") "--setup-only",
action="store_true",
help="Only generate the setup file, don't run",
)
fu_parser.add_argument(
"--no-template-counting",
action="store_true",
help="No counting of templates, useful if the setup is predefined",
)
parser.add_argument( parser.add_argument(
'-N', type=int, default=3, metavar='N', "-N",
help="Number of threads to use when running in parallel") type=int,
parser.add_argument('unittest_args', nargs='*') default=3,
metavar="N",
help="Number of threads to use when running in parallel",
)
parser.add_argument("unittest_args", nargs="*")
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()
sys.argv[1:] = args.unittest_args sys.argv[1:] = args.unittest_args
if args.quite or args.no_interactive: if args.quite or args.no_interactive:
def tqdm(x, *args, **kwargs): def tqdm(x, *args, **kwargs):
return x return x
else: else:
tqdm = set_up_optional_tqdm() tqdm = set_up_optional_tqdm()
logger = logging.getLogger() logger = logging.getLogger()
stream_handler = logging.StreamHandler() stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter( stream_handler.setFormatter(
'%(asctime)s %(levelname)-8s: %(message)s', datefmt='%H:%M')) logging.Formatter("%(asctime)s %(levelname)-8s: %(message)s", datefmt="%H:%M")
)
if args.quite: if args.quite:
logger.setLevel(logging.WARNING) logger.setLevel(logging.WARNING)
...@@ -91,39 +124,45 @@ def set_up_command_line_arguments(): ...@@ -91,39 +124,45 @@ def set_up_command_line_arguments():
def get_ephemeris_files(): def get_ephemeris_files():
""" Returns the earth_ephem and sun_ephem """ """ Returns the earth_ephem and sun_ephem """
config_file = os.path.expanduser('~')+'/.pyfstat.conf' config_file = os.path.expanduser("~") + "/.pyfstat.conf"
env_var = 'LALPULSAR_DATADIR' env_var = "LALPULSAR_DATADIR"
please = 'Please provide the ephemerides paths when initialising searches.' please = "Please provide the ephemerides paths when initialising searches."
if os.path.isfile(config_file): if os.path.isfile(config_file):
d = {} d = {}
with open(config_file, 'r') as f: with open(config_file, "r") as f:
for line in f: for line in f:
k, v = line.split('=') k, v = line.split("=")
k = k.replace(' ', '') k = k.replace(" ", "")
for item in [' ', "'", '"', '\n']: for item in [" ", "'", '"', "\n"]:
v = v.replace(item, '') v = v.replace(item, "")
d[k] = v d[k] = v
try: try:
earth_ephem = d['earth_ephem'] earth_ephem = d["earth_ephem"]
sun_ephem = d['sun_ephem'] sun_ephem = d["sun_ephem"]
except: except:
logging.warning('No [earth/sun]_ephem found in '+config_file+'. '+please) logging.warning(
"No [earth/sun]_ephem found in " + config_file + ". " + please
)
earth_ephem = None earth_ephem = None
sun_ephem = None sun_ephem = None
elif env_var in list(os.environ.keys()): elif env_var in list(os.environ.keys()):
earth_ephem = os.path.join(os.environ[env_var],'earth00-40-DE421.dat.gz') earth_ephem = os.path.join(os.environ[env_var], "earth00-40-DE421.dat.gz")
sun_ephem = os.path.join(os.environ[env_var],'sun00-40-DE421.dat.gz') sun_ephem = os.path.join(os.environ[env_var], "sun00-40-DE421.dat.gz")
if not (os.path.isfile(earth_ephem) and os.path.isfile(sun_ephem)): if not (os.path.isfile(earth_ephem) and os.path.isfile(sun_ephem)):
earth_ephem = os.path.join(os.environ[env_var],'earth00-19-DE421.dat.gz') earth_ephem = os.path.join(os.environ[env_var], "earth00-19-DE421.dat.gz")
sun_ephem = os.path.join(os.environ[env_var],'sun00-19-DE421.dat.gz') sun_ephem = os.path.join(os.environ[env_var], "sun00-19-DE421.dat.gz")
if not (os.path.isfile(earth_ephem) and os.path.isfile(sun_ephem)): if not (os.path.isfile(earth_ephem) and os.path.isfile(sun_ephem)):
logging.warning('No [earth/sun]00-[19/40]-DE421 ephemerides ' logging.warning(
'found in the '+os.environ[env_var]+' directory. '+please) "No [earth/sun]00-[19/40]-DE421 ephemerides "
"found in the " + os.environ[env_var] + " directory. " + please
)
earth_ephem = None earth_ephem = None
sun_ephem = None sun_ephem = None
else: else:
logging.warning('No '+config_file+' file or $'+env_var+' environment ' logging.warning(
'variable found. '+please) "No " + config_file + " file or $" + env_var + " environment "
"variable found. " + please
)
earth_ephem = None earth_ephem = None
sun_ephem = None sun_ephem = None
return earth_ephem, sun_ephem return earth_ephem, sun_ephem
...@@ -133,7 +172,7 @@ def round_to_n(x, n): ...@@ -133,7 +172,7 @@ def round_to_n(x, n):
if not x: if not x:
return 0 return 0
power = -int(np.floor(np.log10(abs(x)))) + (n - 1) power = -int(np.floor(np.log10(abs(x)))) + (n - 1)
factor = (10 ** power) factor = 10 ** power
return round(x * factor) / factor return round(x * factor) / factor
...@@ -150,7 +189,7 @@ def texify_float(x, d=2): ...@@ -150,7 +189,7 @@ def texify_float(x, d=2):
stem = np.round(x / 10 ** power, d) stem = np.round(x / 10 ** power, d)
if d == 1: if d == 1:
stem = int(stem) stem = int(stem)
return r'${}{{\times}}10^{{{}}}$'.format(stem, power) return r"${}{{\times}}10^{{{}}}$".format(stem, power)
def initializer(func): def initializer(func):
...@@ -176,7 +215,7 @@ def get_peak_values(frequencies, twoF, threshold_2F, F0=None, F0range=None): ...@@ -176,7 +215,7 @@ def get_peak_values(frequencies, twoF, threshold_2F, F0=None, F0range=None):
cut_idxs = np.abs(frequencies - F0) < F0range cut_idxs = np.abs(frequencies - F0) < F0range
frequencies = frequencies[cut_idxs] frequencies = frequencies[cut_idxs]
twoF = twoF[cut_idxs] twoF = twoF[cut_idxs]
idxs = peakutils.indexes(twoF, thres=1.*threshold_2F/np.max(twoF)) idxs = peakutils.indexes(twoF, thres=1.0 * threshold_2F / np.max(twoF))
F0maxs = frequencies[idxs] F0maxs = frequencies[idxs]
twoFmaxs = twoF[idxs] twoFmaxs = twoF[idxs]
freq_err = frequencies[1] - frequencies[0] freq_err = frequencies[1] - frequencies[0]
...@@ -184,9 +223,9 @@ def get_peak_values(frequencies, twoF, threshold_2F, F0=None, F0range=None): ...@@ -184,9 +223,9 @@ def get_peak_values(frequencies, twoF, threshold_2F, F0=None, F0range=None):
def get_comb_values(F0, frequencies, twoF, period, N=4): def get_comb_values(F0, frequencies, twoF, period, N=4):
if period == 'sidereal': if period == "sidereal":
period = 23 * 60 * 60 + 56 * 60 + 4.0616 period = 23 * 60 * 60 + 56 * 60 + 4.0616
elif period == 'terrestrial': elif period == "terrestrial":
period = 86400 period = 86400
freq_err = frequencies[1] - frequencies[0] freq_err = frequencies[1] - frequencies[0]
comb_frequencies = [n * 1 / period for n in range(-N, N + 1)] comb_frequencies = [n * 1 / period for n in range(-N, N + 1)]
...@@ -198,12 +237,13 @@ def compute_P_twoFstarcheck(twoFstarcheck, twoFcheck, M0, plot=False): ...@@ -198,12 +237,13 @@ def compute_P_twoFstarcheck(twoFstarcheck, twoFcheck, M0, plot=False):
""" Returns the unnormalised pdf of twoFstarcheck given twoFcheck """ """ Returns the unnormalised pdf of twoFstarcheck given twoFcheck """
upper = 4 + twoFstarcheck + 0.5 * (2 * (4 * M0 + 2 * twoFcheck)) upper = 4 + twoFstarcheck + 0.5 * (2 * (4 * M0 + 2 * twoFcheck))
rho2starcheck = np.linspace(1e-1, upper, 500) rho2starcheck = np.linspace(1e-1, upper, 500)
integrand = (ncx2.pdf(twoFstarcheck, 4*M0, rho2starcheck) integrand = ncx2.pdf(twoFstarcheck, 4 * M0, rho2starcheck) * ncx2.pdf(
* ncx2.pdf(twoFcheck, 4, rho2starcheck)) twoFcheck, 4, rho2starcheck
)
if plot: if plot:
fig, ax = plt.subplots() fig, ax = plt.subplots()
ax.plot(rho2starcheck, integrand) ax.plot(rho2starcheck, integrand)
fig.savefig('test') fig.savefig("test")
return np.trapz(integrand, rho2starcheck) return np.trapz(integrand, rho2starcheck)
...@@ -212,8 +252,11 @@ def compute_pstar(twoFcheck_obs, twoFstarcheck_obs, m0, plot=False): ...@@ -212,8 +252,11 @@ def compute_pstar(twoFcheck_obs, twoFstarcheck_obs, m0, plot=False):
upper = 4 + twoFcheck_obs + (2 * (4 * M0 + 2 * twoFcheck_obs)) upper = 4 + twoFcheck_obs + (2 * (4 * M0 + 2 * twoFcheck_obs))
twoFstarcheck_vals = np.linspace(1e-1, upper, 500) twoFstarcheck_vals = np.linspace(1e-1, upper, 500)
P_twoFstarcheck = np.array( P_twoFstarcheck = np.array(
[compute_P_twoFstarcheck(twoFstarcheck, twoFcheck_obs, M0) [
for twoFstarcheck in twoFstarcheck_vals]) compute_P_twoFstarcheck(twoFstarcheck, twoFcheck_obs, M0)
for twoFstarcheck in twoFstarcheck_vals
]
)
C = np.trapz(P_twoFstarcheck, twoFstarcheck_vals) C = np.trapz(P_twoFstarcheck, twoFstarcheck_vals)
idx = np.argmin(np.abs(twoFstarcheck_vals - twoFstarcheck_obs)) idx = np.argmin(np.abs(twoFstarcheck_vals - twoFstarcheck_obs))
if plot: if plot:
...@@ -221,7 +264,7 @@ def compute_pstar(twoFcheck_obs, twoFstarcheck_obs, m0, plot=False): ...@@ -221,7 +264,7 @@ def compute_pstar(twoFcheck_obs, twoFstarcheck_obs, m0, plot=False):
ax.plot(twoFstarcheck_vals, P_twoFstarcheck) ax.plot(twoFstarcheck_vals, P_twoFstarcheck)
ax.fill_between(twoFstarcheck_vals[: idx + 1], 0, P_twoFstarcheck[: idx + 1]) ax.fill_between(twoFstarcheck_vals[: idx + 1], 0, P_twoFstarcheck[: idx + 1])
ax.axvline(twoFstarcheck_vals[idx]) ax.axvline(twoFstarcheck_vals[idx])
fig.savefig('test') fig.savefig("test")
pstar_l = np.trapz(P_twoFstarcheck[: idx + 1] / C, twoFstarcheck_vals[: idx + 1]) pstar_l = np.trapz(P_twoFstarcheck[: idx + 1] / C, twoFstarcheck_vals[: idx + 1])
return 2 * np.min([pstar_l, 1 - pstar_l]) return 2 * np.min([pstar_l, 1 - pstar_l])
...@@ -239,28 +282,28 @@ def run_commandline(cl, log_level=20, raise_error=True, return_output=True): ...@@ -239,28 +282,28 @@ def run_commandline(cl, log_level=20, raise_error=True, return_output=True):
""" """
logging.log(log_level, 'Now executing: ' + cl) logging.log(log_level, "Now executing: " + cl)
if return_output: if return_output:
try: try:
out = subprocess.check_output(cl, # what to run out = subprocess.check_output(
cl, # what to run
stderr=subprocess.STDOUT, # catch errors stderr=subprocess.STDOUT, # catch errors
shell=True, # proper environment etc shell=True, # proper environment etc
universal_newlines=True, # properly display linebreaks in error/output printing universal_newlines=True, # properly display linebreaks in error/output printing
) )
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
logging.log(log_level, 'Execution failed: {}'.format(e.output)) logging.log(log_level, "Execution failed: {}".format(e.output))
if raise_error: if raise_error:
raise raise
else: else:
out = 0 out = 0
os.system('\n') os.system("\n")
return(out) return out
else: else:
process = subprocess.Popen(cl, shell=True) process = subprocess.Popen(cl, shell=True)
process.communicate() process.communicate()
def convert_array_to_gsl_matrix(array): def convert_array_to_gsl_matrix(array):
gsl_matrix = lal.gsl_matrix(*array.shape) gsl_matrix = lal.gsl_matrix(*array.shape)
gsl_matrix.data = array gsl_matrix.data = array
...@@ -270,8 +313,7 @@ def convert_array_to_gsl_matrix(array): ...@@ -270,8 +313,7 @@ def convert_array_to_gsl_matrix(array):
def get_sft_array(sftfilepattern, data_duration, F0, dF0): def get_sft_array(sftfilepattern, data_duration, F0, dF0):
""" Return the raw data from a set of sfts """ """ Return the raw data from a set of sfts """
SFTCatalog = lalpulsar.SFTdataFind( SFTCatalog = lalpulsar.SFTdataFind(sftfilepattern, lalpulsar.SFTConstraints())
sftfilepattern, lalpulsar.SFTConstraints())
MultiSFTs = lalpulsar.LoadMultiSFTs(SFTCatalog, F0 - dF0, F0 + dF0) MultiSFTs = lalpulsar.LoadMultiSFTs(SFTCatalog, F0 - dF0, F0 + dF0)
SFTs = MultiSFTs.data[0] SFTs = MultiSFTs.data[0]
data = [] data = []
...@@ -318,8 +360,9 @@ def get_covering_band(tref, tstart, tend, F0, F1, F2): ...@@ -318,8 +360,9 @@ def get_covering_band(tref, tstart, tend, F0, F1, F2):
return lalpulsar.CWSignalCoveringBand(tstart, tend, psr, 0, 0, 0) return lalpulsar.CWSignalCoveringBand(tstart, tend, psr, 0, 0, 0)
def twoFDMoffThreshold(twoFon, knee=400, twoFDMoffthreshold_below_threshold=62, def twoFDMoffThreshold(
prefactor=0.9, offset=0.5): twoFon, knee=400, twoFDMoffthreshold_below_threshold=62, prefactor=0.9, offset=0.5
):
""" Calculation of the 2F_DMoff threshold, see Eq 2 of arXiv:1707.5286 """ """ Calculation of the 2F_DMoff threshold, see Eq 2 of arXiv:1707.5286 """
if twoFon <= knee: if twoFon <= knee:
return twoFDMoffthreshold_below_threshold return twoFDMoffthreshold_below_threshold
......
...@@ -5,12 +5,13 @@ ...@@ -5,12 +5,13 @@
import numpy as np import numpy as np
import logging import logging
try: try:
from astropy import units as u from astropy import units as u
from astropy.coordinates import SkyCoord from astropy.coordinates import SkyCoord
from astropy.time import Time from astropy.time import Time
except ImportError: except ImportError:
logging.warning('Python module astropy not installed') logging.warning("Python module astropy not installed")
import lal import lal
# Assume Earth goes around Sun in a non-wobbling circle at constant speed; # Assume Earth goes around Sun in a non-wobbling circle at constant speed;
...@@ -20,20 +21,18 @@ import lal ...@@ -20,20 +21,18 @@ import lal
def _eqToEcl(alpha, delta): def _eqToEcl(alpha, delta):
source = SkyCoord(alpha*u.radian, delta*u.radian, frame='gcrs') source = SkyCoord(alpha * u.radian, delta * u.radian, frame="gcrs")
out = source.transform_to('geocentrictrueecliptic') out = source.transform_to("geocentrictrueecliptic")
return np.array([out.lon.radian, out.lat.radian]) return np.array([out.lon.radian, out.lat.radian])
def _eclToEq(lon, lat): def _eclToEq(lon, lat):
source = SkyCoord(lon*u.radian, lat*u.radian, source = SkyCoord(lon * u.radian, lat * u.radian, frame="geocentrictrueecliptic")
frame='geocentrictrueecliptic') out = source.transform_to("gcrs")
out = source.transform_to('gcrs')
return np.array([out.ra.radian, out.dec.radian]) return np.array([out.ra.radian, out.dec.radian])
def _calcDopplerWings( def _calcDopplerWings(s_freq, s_alpha, s_delta, lonStart, lonStop, numTimes=100):
s_freq, s_alpha, s_delta, lonStart, lonStop, numTimes=100):
e_longitudes = np.linspace(lonStart, lonStop, numTimes) e_longitudes = np.linspace(lonStart, lonStop, numTimes)
v_over_c = 2 * np.pi * lal.AU_SI / lal.YRSID_SI / lal.C_SI v_over_c = 2 * np.pi * lal.AU_SI / lal.YRSID_SI / lal.C_SI
s_lon, s_lat = _eqToEcl(s_alpha, s_delta) s_lon, s_lat = _eqToEcl(s_alpha, s_delta)
...@@ -54,8 +53,7 @@ def _calcSpindownWings(freq, fdot, minStartTime, maxStartTime): ...@@ -54,8 +53,7 @@ def _calcSpindownWings(freq, fdot, minStartTime, maxStartTime):
return 0.5 * timespan * np.abs(fdot) * np.array([-1, 1]) return 0.5 * timespan * np.abs(fdot) * np.array([-1, 1])
def get_frequency_range_of_signal(F0, F1, Alpha, Delta, minStartTime, def get_frequency_range_of_signal(F0, F1, Alpha, Delta, minStartTime, maxStartTime):
maxStartTime):
""" Calculate the frequency range that a signal will occupy """ Calculate the frequency range that a signal will occupy
Parameters Parameters
...@@ -78,8 +76,8 @@ def get_frequency_range_of_signal(F0, F1, Alpha, Delta, minStartTime, ...@@ -78,8 +76,8 @@ def get_frequency_range_of_signal(F0, F1, Alpha, Delta, minStartTime,
YEAR_IN_DAYS = lal.YRSID_SI / lal.DAYSID_SI YEAR_IN_DAYS = lal.YRSID_SI / lal.DAYSID_SI
tEquinox = 79 tEquinox = 79
minStartTime_t = Time(minStartTime, format='gps').to_datetime().timetuple() minStartTime_t = Time(minStartTime, format="gps").to_datetime().timetuple()
maxStartTime_t = Time(maxStartTime, format='gps').to_datetime().timetuple() maxStartTime_t = Time(maxStartTime, format="gps").to_datetime().timetuple()
tStart_days = minStartTime_t.tm_yday - tEquinox tStart_days = minStartTime_t.tm_yday - tEquinox
tStop_days = maxStartTime_t.tm_yday - tEquinox tStop_days = maxStartTime_t.tm_yday - tEquinox
tStop_days += (maxStartTime_t.tm_year - minStartTime_t.tm_year) * YEAR_IN_DAYS tStop_days += (maxStartTime_t.tm_year - minStartTime_t.tm_year) * YEAR_IN_DAYS
......
This diff is collapsed.
This diff is collapsed.
...@@ -14,8 +14,8 @@ import pyfstat.helper_functions as helper_functions ...@@ -14,8 +14,8 @@ import pyfstat.helper_functions as helper_functions
def get_optimal_setup( def get_optimal_setup(
NstarMax, Nsegs0, tref, minStartTime, maxStartTime, prior, NstarMax, Nsegs0, tref, minStartTime, maxStartTime, prior, detector_names
detector_names): ):
""" Using an optimisation step, calculate the optimal setup ladder """ Using an optimisation step, calculate the optimal setup ladder
Parameters Parameters
...@@ -37,14 +37,14 @@ def get_optimal_setup( ...@@ -37,14 +37,14 @@ def get_optimal_setup(
""" """
logging.info('Calculating optimal setup for NstarMax={}, Nsegs0={}'.format( logging.info(
NstarMax, Nsegs0)) "Calculating optimal setup for NstarMax={}, Nsegs0={}".format(NstarMax, Nsegs0)
)
Nstar_0 = get_Nstar_estimate( Nstar_0 = get_Nstar_estimate(
Nsegs0, tref, minStartTime, maxStartTime, prior, Nsegs0, tref, minStartTime, maxStartTime, prior, detector_names
detector_names) )
logging.info( logging.info("Stage {}, nsegs={}, Nstar={}".format(0, Nsegs0, int(Nstar_0)))
'Stage {}, nsegs={}, Nstar={}'.format(0, Nsegs0, int(Nstar_0)))
nsegs_vals = [Nsegs0] nsegs_vals = [Nsegs0]
Nstar_vals = [Nstar_0] Nstar_vals = [Nstar_0]
...@@ -53,25 +53,27 @@ def get_optimal_setup( ...@@ -53,25 +53,27 @@ def get_optimal_setup(
nsegs_i = Nsegs0 nsegs_i = Nsegs0
while nsegs_i > 1: while nsegs_i > 1:
nsegs_i, Nstar_i = _get_nsegs_ip1( nsegs_i, Nstar_i = _get_nsegs_ip1(
nsegs_i, NstarMax, tref, minStartTime, maxStartTime, prior, nsegs_i, NstarMax, tref, minStartTime, maxStartTime, prior, detector_names
detector_names) )
nsegs_vals.append(nsegs_i) nsegs_vals.append(nsegs_i)
Nstar_vals.append(Nstar_i) Nstar_vals.append(Nstar_i)
i += 1 i += 1
logging.info( logging.info("Stage {}, nsegs={}, Nstar={}".format(i, nsegs_i, int(Nstar_i)))
'Stage {}, nsegs={}, Nstar={}'.format(i, nsegs_i, int(Nstar_i)))
return nsegs_vals, Nstar_vals return nsegs_vals, Nstar_vals
def _get_nsegs_ip1(nsegs_i, NstarMax, tref, minStartTime, maxStartTime, prior, def _get_nsegs_ip1(
detector_names): nsegs_i, NstarMax, tref, minStartTime, maxStartTime, prior, detector_names
):
""" Calculate Nsegs_{i+1} given Nsegs_{i} """ """ Calculate Nsegs_{i+1} given Nsegs_{i} """
log10NstarMax = np.log10(NstarMax) log10NstarMax = np.log10(NstarMax)
log10Nstari = np.log10(get_Nstar_estimate( log10Nstari = np.log10(
nsegs_i, tref, minStartTime, maxStartTime, prior, get_Nstar_estimate(
detector_names)) nsegs_i, tref, minStartTime, maxStartTime, prior, detector_names
)
)
def f(nsegs_ip1): def f(nsegs_ip1):
if nsegs_ip1[0] > nsegs_i: if nsegs_ip1[0] > nsegs_i:
...@@ -82,24 +84,30 @@ def _get_nsegs_ip1(nsegs_i, NstarMax, tref, minStartTime, maxStartTime, prior, ...@@ -82,24 +84,30 @@ def _get_nsegs_ip1(nsegs_i, NstarMax, tref, minStartTime, maxStartTime, prior,
if nsegs_ip1 == 0: if nsegs_ip1 == 0:
nsegs_ip1 = 1 nsegs_ip1 = 1
Nstarip1 = get_Nstar_estimate( Nstarip1 = get_Nstar_estimate(
nsegs_ip1, tref, minStartTime, maxStartTime, prior, detector_names) nsegs_ip1, tref, minStartTime, maxStartTime, prior, detector_names
)
if Nstarip1 is None: if Nstarip1 is None:
return 1e6 return 1e6
else: else:
log10Nstarip1 = np.log10(Nstarip1) log10Nstarip1 = np.log10(Nstarip1)
return np.abs(log10Nstari + log10NstarMax - log10Nstarip1) return np.abs(log10Nstari + log10NstarMax - log10Nstarip1)
res = scipy.optimize.minimize(f, .4*nsegs_i, method='Powell', tol=1,
options={'maxiter': 10}) res = scipy.optimize.minimize(
logging.info('{} with {} evaluations'.format(res['message'], res['nfev'])) f, 0.4 * nsegs_i, method="Powell", tol=1, options={"maxiter": 10}
)
logging.info("{} with {} evaluations".format(res["message"], res["nfev"]))
nsegs_ip1 = int(res.x) nsegs_ip1 = int(res.x)
if nsegs_ip1 == 0: if nsegs_ip1 == 0:
nsegs_ip1 = 1 nsegs_ip1 = 1
if res.success: if res.success:
return nsegs_ip1, get_Nstar_estimate( return (
nsegs_ip1, tref, minStartTime, maxStartTime, prior, nsegs_ip1,
detector_names) get_Nstar_estimate(
nsegs_ip1, tref, minStartTime, maxStartTime, prior, detector_names
),
)
else: else:
raise ValueError('Optimisation unsuccesful') raise ValueError("Optimisation unsuccesful")
def _extract_data_from_prior(prior): def _extract_data_from_prior(prior):
...@@ -121,7 +129,7 @@ def _extract_data_from_prior(prior): ...@@ -121,7 +129,7 @@ def _extract_data_from_prior(prior):
Fidicual frequency Fidicual frequency
""" """
keys = ['Alpha', 'Delta', 'F0', 'F1', 'F2'] keys = ["Alpha", "Delta", "F0", "F1", "F2"]
spindown_keys = keys[3:] spindown_keys = keys[3:]
sky_keys = keys[:2] sky_keys = keys[:2]
lims = [] lims = []
...@@ -129,14 +137,14 @@ def _extract_data_from_prior(prior): ...@@ -129,14 +137,14 @@ def _extract_data_from_prior(prior):
lims_idxs = [] lims_idxs = []
for i, key in enumerate(keys): for i, key in enumerate(keys):
if type(prior[key]) == dict: if type(prior[key]) == dict:
if prior[key]['type'] == 'unif': if prior[key]["type"] == "unif":
lims.append([prior[key]['lower'], prior[key]['upper']]) lims.append([prior[key]["lower"], prior[key]["upper"]])
lims_keys.append(key) lims_keys.append(key)
lims_idxs.append(i) lims_idxs.append(i)
else: else:
raise ValueError( raise ValueError(
"Prior type {} not yet supported".format( "Prior type {} not yet supported".format(prior[key]["type"])
prior[key]['type'])) )
elif key not in spindown_keys: elif key not in spindown_keys:
lims.append([prior[key], 0]) lims.append([prior[key], 0])
lims = np.array(lims) lims = np.array(lims)
...@@ -149,16 +157,15 @@ def _extract_data_from_prior(prior): ...@@ -149,16 +157,15 @@ def _extract_data_from_prior(prior):
p.append(basex) p.append(basex)
spindowns = np.sum([np.sum(lims_keys == k) for k in spindown_keys]) spindowns = np.sum([np.sum(lims_keys == k) for k in spindown_keys])
sky = any([key in lims_keys for key in sky_keys]) sky = any([key in lims_keys for key in sky_keys])
if type(prior['F0']) == dict: if type(prior["F0"]) == dict:
fiducial_freq = prior['F0']['upper'] fiducial_freq = prior["F0"]["upper"]
else: else:
fiducial_freq = prior['F0'] fiducial_freq = prior["F0"]
return np.array(p).T, spindowns, sky, fiducial_freq return np.array(p).T, spindowns, sky, fiducial_freq
def get_Nstar_estimate( def get_Nstar_estimate(nsegs, tref, minStartTime, maxStartTime, prior, detector_names):
nsegs, tref, minStartTime, maxStartTime, prior, detector_names):
""" Returns N* estimated from the super-sky metric """ Returns N* estimated from the super-sky metric
Parameters Parameters
...@@ -194,24 +201,30 @@ def get_Nstar_estimate( ...@@ -194,24 +201,30 @@ def get_Nstar_estimate(
ref_time = lal.LIGOTimeGPS(tref) ref_time = lal.LIGOTimeGPS(tref)
segments = lal.SegListCreate() segments = lal.SegListCreate()
for j in range(len(tboundaries) - 1): for j in range(len(tboundaries) - 1):
seg = lal.SegCreate(lal.LIGOTimeGPS(tboundaries[j]), seg = lal.SegCreate(
lal.LIGOTimeGPS(tboundaries[j+1]), lal.LIGOTimeGPS(tboundaries[j]), lal.LIGOTimeGPS(tboundaries[j + 1]), j
j) )
lal.SegListAppend(segments, seg) lal.SegListAppend(segments, seg)
detNames = lal.CreateStringVector(*detector_names) detNames = lal.CreateStringVector(*detector_names)
detectors = lalpulsar.MultiLALDetector() detectors = lalpulsar.MultiLALDetector()
lalpulsar.ParseMultiLALDetector(detectors, detNames) lalpulsar.ParseMultiLALDetector(detectors, detNames)
detector_weights = None detector_weights = None
detector_motion = (lalpulsar.DETMOTION_SPIN detector_motion = lalpulsar.DETMOTION_SPIN + lalpulsar.DETMOTION_ORBIT
+ lalpulsar.DETMOTION_ORBIT)
ephemeris = lalpulsar.InitBarycenter(earth_ephem, sun_ephem) ephemeris = lalpulsar.InitBarycenter(earth_ephem, sun_ephem)
try: try:
SSkyMetric = lalpulsar.ComputeSuperskyMetrics( SSkyMetric = lalpulsar.ComputeSuperskyMetrics(
lalpulsar.SUPERSKY_METRIC_TYPE, spindowns, ref_time, segments, lalpulsar.SUPERSKY_METRIC_TYPE,
fiducial_freq, detectors, detector_weights, detector_motion, spindowns,
ephemeris) ref_time,
segments,
fiducial_freq,
detectors,
detector_weights,
detector_motion,
ephemeris,
)
except RuntimeError as e: except RuntimeError as e:
logging.warning('Encountered run-time error {}'.format(e)) logging.warning("Encountered run-time error {}".format(e))
raise RuntimeError("Calculation of the SSkyMetric failed") raise RuntimeError("Calculation of the SSkyMetric failed")
if sky: if sky:
...@@ -220,7 +233,8 @@ def get_Nstar_estimate( ...@@ -220,7 +233,8 @@ def get_Nstar_estimate(
i = 2 i = 2
lalpulsar.ConvertPhysicalToSuperskyPoints( lalpulsar.ConvertPhysicalToSuperskyPoints(
out_rssky, in_phys, SSkyMetric.semi_rssky_transf) out_rssky, in_phys, SSkyMetric.semi_rssky_transf
)
d = out_rssky.data d = out_rssky.data
...@@ -234,6 +248,9 @@ def get_Nstar_estimate( ...@@ -234,6 +248,9 @@ def get_Nstar_estimate(
dV = np.abs(np.linalg.det(parallelepiped[:j, :j])) dV = np.abs(np.linalg.det(parallelepiped[:j, :j]))
sqrtdetG = np.sqrt(np.abs(np.linalg.det(g[:j, :j]))) sqrtdetG = np.sqrt(np.abs(np.linalg.det(g[:j, :j])))
Nstars.append(sqrtdetG * dV) Nstars.append(sqrtdetG * dV)
logging.debug('Nstar for each dimension = {}'.format( logging.debug(
', '.join(["{:1.1e}".format(n) for n in Nstars]))) "Nstar for each dimension = {}".format(
", ".join(["{:1.1e}".format(n) for n in Nstars])
)
)
return np.max(Nstars) return np.max(Nstars)
This diff is collapsed.
...@@ -8,3 +8,4 @@ tqdm ...@@ -8,3 +8,4 @@ tqdm
bashplotlib bashplotlib
peakutils peakutils
pathos pathos
pycuda
[metadata]
license_file = LICENSE.txt
[flake8]
exclude = .git,docs,build,dist,tests,*__init__.py
max-line-length = 80
select = C,E,F,W,B,B950
ignore = E501,W503,E203
...@@ -5,27 +5,32 @@ from os import path ...@@ -5,27 +5,32 @@ from os import path
here = path.abspath(path.dirname(__file__)) here = path.abspath(path.dirname(__file__))
# Get the long description from the README file # Get the long description from the README file
with open(path.join(here, 'README.md'), encoding='utf-8') as f: with open(path.join(here, "README.md"), encoding="utf-8") as f:
long_description = f.read() long_description = f.read()
setup(name='PyFstat', setup(
version='0.2', name="PyFstat",
author='Gregory Ashton', version="0.2",
author_email='gregory.ashton@ligo.org', author="Gregory Ashton",
author_email="gregory.ashton@ligo.org",
packages=find_packages(where="pyfstat"), packages=find_packages(where="pyfstat"),
include_package_data=True, include_package_data=True,
package_data={'pyfstat': ['pyCUDAkernels/cudaTransientFstatExpWindow.cu', package_data={
'pyCUDAkernels/cudaTransientFstatRectWindow.cu']}, "pyfstat": [
"pyCUDAkernels/cudaTransientFstatExpWindow.cu",
"pyCUDAkernels/cudaTransientFstatRectWindow.cu",
]
},
install_requires=[ install_requires=[
'matplotlib', "matplotlib",
'scipy', "scipy",
'ptemcee', "ptemcee",
'corner', "corner",
'dill', "dill",
'tqdm', "tqdm",
'bashplotlib', "bashplotlib",
'peakutils', "peakutils",
'pathos', "pathos",
'pycuda', "pycuda",
], ],
) )
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment