Skip to content
Snippets Groups Projects
Commit d6edafae authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Adds initial check of the typical number of segments

The motivating idea here is to give the user an idea of how well the
MCMC is likely to perform: if the typical number of templates is of
order millions or more then the MCMC will work like a random template
back, jumping back and forth without finding a peak (for a long time).
parent fdc3033a
No related branches found
No related tags found
No related merge requests found
...@@ -53,11 +53,11 @@ parser.add_argument("-q", "--quite", help="Decrease output verbosity", ...@@ -53,11 +53,11 @@ parser.add_argument("-q", "--quite", help="Decrease output verbosity",
parser.add_argument("-c", "--clean", help="Don't use cached data", parser.add_argument("-c", "--clean", help="Don't use cached data",
action="store_true") action="store_true")
parser.add_argument("-u", "--use-old-data", action="store_true") parser.add_argument("-u", "--use-old-data", action="store_true")
parser.add_argument('-s', "--setup-only", action="store_true")
parser.add_argument('unittest_args', nargs='*') parser.add_argument('unittest_args', nargs='*')
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()
sys.argv[1:] = args.unittest_args sys.argv[1:] = args.unittest_args
logger = logging.getLogger() logger = logging.getLogger()
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
stream_handler = logging.StreamHandler() stream_handler = logging.StreamHandler()
...@@ -268,6 +268,7 @@ class ComputeFstat(object): ...@@ -268,6 +268,7 @@ class ComputeFstat(object):
self.sftfilepath)) self.sftfilepath))
SFTCatalog = lalpulsar.SFTdataFind(self.sftfilepath, constraints) SFTCatalog = lalpulsar.SFTdataFind(self.sftfilepath, constraints)
names = list(set([d.header.name for d in SFTCatalog.data])) names = list(set([d.header.name for d in SFTCatalog.data]))
self.names = names
SFT_timestamps = [d.header.epoch for d in SFTCatalog.data] SFT_timestamps = [d.header.epoch for d in SFTCatalog.data]
try: try:
from bashplotlib.histogram import plot_hist from bashplotlib.histogram import plot_hist
...@@ -285,7 +286,7 @@ class ComputeFstat(object): ...@@ -285,7 +286,7 @@ class ComputeFstat(object):
int(SFT_timestamps[0])), shell=True).rstrip('\n'), int(SFT_timestamps[0])), shell=True).rstrip('\n'),
int(SFT_timestamps[-1]), int(SFT_timestamps[-1]),
subprocess.check_output('lalapps_tconvert {}'.format( subprocess.check_output('lalapps_tconvert {}'.format(
int(SFT_timestamps[-1])), shell=True)).rstrip('\n')) int(SFT_timestamps[-1])), shell=True).rstrip('\n')))
logging.info('Initialising ephems') logging.info('Initialising ephems')
ephems = lalpulsar.InitBarycenter(self.earth_ephem, self.sun_ephem) ephems = lalpulsar.InitBarycenter(self.earth_ephem, self.sun_ephem)
...@@ -1581,8 +1582,8 @@ class MCMCSearch(BaseSearchClass): ...@@ -1581,8 +1582,8 @@ class MCMCSearch(BaseSearchClass):
tglitches = [d['tglitch']] tglitches = [d['tglitch']]
else: else:
tglitches = [d['tglitch_{}'.format(i)] for i in range(self.nglitch)] tglitches = [d['tglitch_{}'.format(i)] for i in range(self.nglitch)]
tbounderies = [self.tstart] + tglitches + [self.tend] tboundaries = [self.tstart] + tglitches + [self.tend]
deltaTs = np.diff(tbounderies) deltaTs = np.diff(tboundaries)
ntrials = [time_trials + delta_F0 * dT for dT in deltaTs] ntrials = [time_trials + delta_F0 * dT for dT in deltaTs]
p_val = self.p_val_twoFhat(max_twoF, ntrials) p_val = self.p_val_twoFhat(max_twoF, ntrials)
print('p-value = {}'.format(p_val)) print('p-value = {}'.format(p_val))
...@@ -1862,11 +1863,11 @@ _ sftfilepath: str ...@@ -1862,11 +1863,11 @@ _ sftfilepath: str
delta_F0s[:self.theta0_idx] *= -1 delta_F0s[:self.theta0_idx] *= -1
tglitches = [d['tglitch']] tglitches = [d['tglitch']]
tbounderies = [self.tstart] + tglitches + [self.tend] tboundaries = [self.tstart] + tglitches + [self.tend]
for j in range(self.nglitch+1): for j in range(self.nglitch+1):
ts = tbounderies[j] ts = tboundaries[j]
te = tbounderies[j+1] te = tboundaries[j+1]
if (te - ts)/86400 < 5: if (te - ts)/86400 < 5:
logging.info('Period too short to perform cumulative search') logging.info('Period too short to perform cumulative search')
continue continue
...@@ -1960,6 +1961,70 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch): ...@@ -1960,6 +1961,70 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
BSGL=self.BSGL, run_setup=self.run_setup) BSGL=self.BSGL, run_setup=self.run_setup)
return d return d
def get_width_from_prior(self, prior, key):
if prior[key]['type'] == 'unif':
return prior[key]['upper'] - prior[key]['lower']
def get_mid_from_prior(self, prior, key):
if prior[key]['type'] == 'unif':
return .5*(prior[key]['upper'] + prior[key]['lower'])
def get_number_of_templates_estimate(self, nsegs):
tboundaries = np.linspace(self.minStartTime, self.maxStartTime,
nsegs+1)
if 'Alpha' in self.theta_keys:
DeltaAlpha = self.get_width_from_prior(self.theta_prior, 'Alpha')
DeltaDelta = self.get_width_from_prior(self.theta_prior, 'Delta')
DeltaMid = self.get_mid_from_prior(self.theta_prior, 'Delta')
DeltaOmega = np.sin(DeltaMid) * DeltaDelta * DeltaAlpha
if 'F0' in self.theta_keys:
DeltaF0 = self.get_width_from_prior(self.theta_prior, 'F0')
else:
DeltaF0 = 1
if 'F1' in self.theta_keys:
DeltaF1 = self.get_width_from_prior(self.theta_prior, 'F1')
spindowns = 1
else:
DeltaF1 = 1
spindowns = 0
ref_time = lal.LIGOTimeGPS(self.tref)
segments = lal.SegListCreate()
for j in range(len(tboundaries)-1):
seg = lal.SegCreate(lal.LIGOTimeGPS(tboundaries[j]),
lal.LIGOTimeGPS(tboundaries[j+1]),
j)
lal.SegListAppend(segments, seg)
if type(self.theta_prior['F0']) == dict:
fiducial_freq = self.get_mid_from_prior(self.theta_prior, 'F0')
else:
fiducial_freq = self.theta_prior['F0']
detector_names = self.search.names
detNames = lal.CreateStringVector(*detector_names)
detectors = lalpulsar.MultiLALDetector()
lalpulsar.ParseMultiLALDetector(detectors, detNames)
detector_weights = None
detector_motion = (lalpulsar.DETMOTION_SPIN
+ lalpulsar.DETMOTION_ORBIT)
ephemeris = lalpulsar.InitBarycenter(self.earth_ephem,
self.sun_ephem)
try:
SSkyMetric = lalpulsar.ComputeSuperskyMetrics(
spindowns, ref_time, segments, fiducial_freq, detectors,
detector_weights, detector_motion, ephemeris)
sqrtdetG_SKY = np.sqrt(np.linalg.det(
SSkyMetric.semi_rssky_metric.data[:2, :2]))
sqrtdetG_F = np.sqrt(np.linalg.det(
SSkyMetric.semi_rssky_metric.data[2:, 2:]))
return '{:1.0e} = 1/2 x {:1.0e} x {:1.0e}'.format(
.5*sqrtdetG_SKY*sqrtdetG_F*DeltaOmega*DeltaF1*DeltaF0,
sqrtdetG_SKY*DeltaOmega,
sqrtdetG_F*DeltaF1*DeltaF0)
except RuntimeError:
return 'N/A'
elif self.theta_keys == ['F0', 'F1']:
return 'N/A'
def run(self, run_setup, proposal_scale_factor=2): def run(self, run_setup, proposal_scale_factor=2):
""" Run the follow-up with the given run_setup """ Run the follow-up with the given run_setup
...@@ -1969,8 +2034,12 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch): ...@@ -1969,8 +2034,12 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
""" """
self.nsegs = 1
self.inititate_search_object()
logging.info('Using run-setup as follow:') logging.info('Using run-setup as follow:')
logging.info('Stage | nburn | nprod | nsegs | resetp0') logging.info('Stage | nburn | nprod | nsegs | resetp0 |'
'# templates = # sky x # Freq')
for i, rs in enumerate(run_setup): for i, rs in enumerate(run_setup):
rs = list(rs) rs = list(rs)
if len(rs) == 2: if len(rs) == 2:
...@@ -1978,9 +2047,15 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch): ...@@ -1978,9 +2047,15 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
if np.shape(rs[0]) == (): if np.shape(rs[0]) == ():
rs[0] = (rs[0], 0) rs[0] = (rs[0], 0)
run_setup[i] = rs run_setup[i] = rs
logging.info('{} | {} | {} | {} | {}'.format( N = self.get_number_of_templates_estimate(rs[1])
logging.info('{} | {} | {} | {} | {} | {}'.format(
str(i).ljust(5), str(rs[0][0]).ljust(5), str(i).ljust(5), str(rs[0][0]).ljust(5),
str(rs[0][1]).ljust(5), str(rs[1]).ljust(5), rs[2])) str(rs[0][1]).ljust(5), str(rs[1]).ljust(5),
str(rs[2]).ljust(7), N))
if args.setup_only:
logging.info("Exit as requested by setup_only flag")
sys.exit()
self.run_setup = run_setup self.run_setup = run_setup
self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use() self.old_data_is_okay_to_use = self.check_old_data_is_okay_to_use()
...@@ -1998,8 +2073,6 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch): ...@@ -1998,8 +2073,6 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
fig = None fig = None
axes = None axes = None
nsteps_total = 0 nsteps_total = 0
self.nsegs = 1
self.inititate_search_object()
for j, ((nburn, nprod), nseg, reset_p0) in enumerate(run_setup): for j, ((nburn, nprod), nseg, reset_p0) in enumerate(run_setup):
if j == 0: if j == 0:
p0 = self.generate_initial_p0() p0 = self.generate_initial_p0()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment