def generate_batch(batch_size, sampling_period, signal_duration, start_period, end_period,
                   start_target_period, end_target_period):
    """
    Generate a stratified batch of sine-wave examples. There are two classes:
    class 0: sine waves with period in [start_target_period, end_target_period]
    class 1: sine waves with period in [start_period, start_target_period] U [end_target_period, end_period]

    The first ``batch_size // 2`` rows are class 0 and the remaining rows are
    class 1, so class 1 receives the extra example when ``batch_size`` is odd.

    :param batch_size: number of samples per batch
    :param sampling_period: sampling period in milliseconds
    :param signal_duration: duration of the sine waves in milliseconds
    :param start_period: lower bound of the overall period range (used for class 1)
    :param end_period: upper bound of the overall period range (used for class 1)
    :param start_target_period: lower bound of the target (class 0) period range
    :param end_target_period: upper bound of the target (class 0) period range
    :return x: float array of shape (batch_size, seq_length, 1) holding the signals
    :return y: int64 array of shape (batch_size,) holding the labels (0 or 1)
    """
    ratio = signal_duration / sampling_period
    nearest = round(ratio)
    # Snap ratios that are integral up to floating-point error
    # (e.g. 0.3 / 0.1 == 2.9999999999999996) instead of truncating them one
    # sample short; any genuinely fractional ratio is still floored.
    if np.isclose(ratio, nearest, rtol=1e-9, atol=0.0):
        seq_length = int(nearest)
    else:
        seq_length = int(ratio)
    n_elems = 1
    x = np.empty((batch_size, seq_length, n_elems))
    y = np.empty(batch_size, dtype=np.int64)
    # Sample instants spaced exactly by sampling_period. The previous
    # np.linspace(0, signal_duration - sampling_period, seq_length) only had
    # this spacing when signal_duration was an exact multiple of sampling_period.
    t = np.arange(seq_length) * sampling_period
    n_target = batch_size // 2
    # Class 0: period drawn from the target interval.
    for idx in range(n_target):
        period = random.uniform(start_target_period, end_target_period)
        # Phase shift drawn from [0, period]; its interpretation is up to
        # generate_example (presumably a time offset — confirm there).
        phase_shift = random.uniform(0, period)
        x[idx, :, 0] = generate_example(t, 1. / period, phase_shift)
        y[idx] = 0
    # Class 1: period drawn by random_disjoint_interval, presumably from the
    # overall range excluding the target interval (see that helper).
    for idx in range(n_target, batch_size):
        period = random_disjoint_interval(start_period, end_period,
                                          start_target_period, end_target_period)
        phase_shift = random.uniform(0, period)
        x[idx, :, 0] = generate_example(t, 1. / period, phase_shift)
        y[idx] = 1
    return x, y
02_frequency_discrimination_task.py 文件源码
python
阅读 28
收藏 0
点赞 0
评论 0
评论列表
文章目录