Commit f9952731 authored by Yifan Wang's avatar Yifan Wang
Browse files

complete the traning data part, also plot some figures

parent 18514b44
......@@ -139,15 +139,16 @@ class SampleFile:
# Read in dict with command_line_arguments
self.data['command_line_arguments'] = \
dict(hdf_file['command_line_arguments'].attrs)
self.data['command_line_arguments'] = \
{key: value.decode('ascii') for key, value in
{key: value for key, value in
iteritems(self.data['command_line_arguments'])}
# Read in dict with static_arguments
self.data['static_arguments'] = \
dict(hdf_file['static_arguments'].attrs)
self.data['static_arguments'] = \
{key: value.decode('ascii') for key, value in
{key: value for key, value in
iteritems(self.data['static_arguments'])}
# Read in group containing injection samples
......
......@@ -12,7 +12,7 @@ from __future__ import print_function
import numpy as np
from lal import LIGOTimeGPS
from pycbc.psd import interpolate
from pycbc.psd import interpolate, inverse_spectrum_truncation
from pycbc.psd.analytical import aLIGOZeroDetHighPower
from pycbc.noise import noise_from_psd
from pycbc.filter import sigma
......@@ -25,6 +25,22 @@ from .waveforms import get_detector_signals, get_waveform
# FUNCTION DEFINITIONS
# -----------------------------------------------------------------------------
def signal_whiten(psd, signal, segment_duration, max_filter_duration,
trunc_method='hann', low_frequency_cutoff=None):
# Estimate the noise spectrum, no need for segment_duration, because we already get in line 193.
psd = interpolate(psd, signal.delta_f)
max_filter_len = int(max_filter_duration * signal.sample_rate)
# Interpolate and smooth to the desired corruption length
psd = inverse_spectrum_truncation(psd,
max_filter_len=max_filter_len,
low_frequency_cutoff=low_frequency_cutoff,
trunc_method=trunc_method)
# Whiten the data by the asd
white = (signal.to_frequencyseries() / psd**0.5).to_timeseries()
return white
def generate_sample(static_arguments,
event_tuple,
waveform_params=None):
......@@ -133,6 +149,7 @@ def generate_sample(static_arguments,
if waveform_params is None:
detector_signals = None
injection_parameters = None
output_signals = None
strain = noise
# Otherwise, we need to simulate a waveform for the given waveform_params
......@@ -154,6 +171,10 @@ def generate_sample(static_arguments,
event_time=event_time,
waveform=waveform)
# Store the output_signal
output_signals = {}
output_signals = detector_signals.copy()
# ---------------------------------------------------------------------
# Add the waveform into the noise as is to calculate the NOMF-SNR
# ---------------------------------------------------------------------
......@@ -199,7 +220,7 @@ def generate_sample(static_arguments,
# injection SNR
strain[det] = noise[det].add_into(scale_factor *
detector_signals[det])
output_signals[det] = scale_factor * output_signals[det]
# ---------------------------------------------------------------------
# Store some information about the injection we just made
# ---------------------------------------------------------------------
......@@ -210,7 +231,7 @@ def generate_sample(static_arguments,
'l1_snr': snrs['L1']}
# Also add the waveform parameters we have sampled
for key, value in waveform_params.items():
for key, value in iter(waveform_params.items()):
injection_parameters[key] = value
# -------------------------------------------------------------------------
......@@ -231,6 +252,13 @@ def generate_sample(static_arguments,
max_filter_duration=max_filter_duration,
remove_corrupted=False)
if waveform_params is not None:
output_signals[det] = \
signal_whiten(psd = psds[det],
signal = output_signals[det],
segment_duration = segment_duration,
max_filter_duration = max_filter_duration)
# Get the limits for the bandpass
bandpass_lower = static_arguments['bandpass_lower']
bandpass_upper = static_arguments['bandpass_upper']
......@@ -241,14 +269,20 @@ def generate_sample(static_arguments,
strain[det] = strain[det].highpass_fir(frequency=bandpass_lower,
remove_corrupted=False,
order=512)
if waveform_params is not None:
output_signals[det] = output_signals[det].highpass_fir(frequency=bandpass_lower,
remove_corrupted=False,
order=512)
# Apply a low-pass filter to remove everything above `bandpass_upper`.
# If bandpass_upper = sampling rate, do not apply any low-pass filter.
if bandpass_upper != target_sampling_rate:
strain[det] = strain[det].lowpass_fir(frequency=bandpass_upper,
remove_corrupted=False,
order=512)
if waveform_params is not None:
output_signals[det] = output_signals[det].lowpass_fir(frequency=bandpass_upper,
remove_corrupted=False,
order=512)
# -------------------------------------------------------------------------
# Cut strain (and signal) time series to the pre-specified length
# -------------------------------------------------------------------------
......@@ -267,6 +301,7 @@ def generate_sample(static_arguments,
# Cut the detector signals to the specified length
detector_signals[det] = detector_signals[det].time_slice(a, b)
output_signals[det] = output_signals[det].time_slice(a, b)
# Also add the detector signals to the injection parameters
injection_parameters['h1_signal'] = \
......@@ -274,6 +309,11 @@ def generate_sample(static_arguments,
injection_parameters['l1_signal'] = \
np.array(detector_signals['L1'])
injection_parameters['h1_output_signal'] = \
np.array(output_signals['H1'])
injection_parameters['l1_output_signal'] = \
np.array(output_signals['L1'])
# -------------------------------------------------------------------------
# Collect all available information about this sample and return results
# -------------------------------------------------------------------------
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment