02_frequency_discrimination_task.py source file

python

Project: skiprnn-2017-telecombcn  Author: imatge-upc
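import random

import numpy as np

# NOTE: generate_example() and random_disjoint_interval() are helpers used below
# whose definitions are not part of this excerpt.


def generate_batch(batch_size, sampling_period, signal_duration, start_period, end_period,
                   start_target_period, end_target_period):
    """
    Generate a stratified batch of examples. There are two classes:
        class 0: sine waves with period in [start_target_period, end_target_period]
        class 1: sine waves with period in [start_period, start_target_period] U [end_target_period, end_period]
    The first batch_size // 2 examples belong to class 0, the rest to class 1.

    :param batch_size: number of samples per batch
    :param sampling_period: sampling period in milliseconds
    :param signal_duration: duration of the sine waves in milliseconds
    :param start_period: lower bound of the overall period range
    :param end_period: upper bound of the overall period range
    :param start_target_period: lower bound of the target (class 0) period range
    :param end_target_period: upper bound of the target (class 0) period range

    :return x: batch of examples, shape (batch_size, seq_length, 1)
    :return y: batch of labels, shape (batch_size,)
    """
    # Number of time steps per sequence
    seq_length = int(signal_duration / sampling_period)

    n_elems = 1  # one feature per time step
    x = np.empty((batch_size, seq_length, n_elems))
    y = np.empty(batch_size, dtype=np.int64)

    # Evenly spaced sampling instants starting at 0 (spaced by sampling_period
    # when signal_duration is a multiple of it)
    t = np.linspace(0, signal_duration - sampling_period, seq_length)

    half = batch_size // 2

    # Class 0: period drawn uniformly from the target interval
    for idx in range(half):
        period = random.uniform(start_target_period, end_target_period)
        phase_shift = random.uniform(0, period)
        x[idx, :, 0] = generate_example(t, 1. / period, phase_shift)
        y[idx] = 0

    # Class 1: period drawn from the overall range excluding the target interval
    for idx in range(half, batch_size):
        period = random_disjoint_interval(start_period, end_period,
                                          start_target_period, end_target_period)
        phase_shift = random.uniform(0, period)
        x[idx, :, 0] = generate_example(t, 1. / period, phase_shift)
        y[idx] = 1
    return x, y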
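The function above relies on two helpers, `generate_example` and `random_disjoint_interval`, whose definitions are not shown. The sketch below gives hypothetical stand-ins so the batch generator can be exercised. It assumes `generate_example(t, frequency, phase_shift)` returns a unit-amplitude sine `sin(2*pi*frequency*(t + phase_shift))` and that `random_disjoint_interval(start, end, avoid_start, avoid_end)` draws uniformly from `[start, avoid_start] U [avoid_end, end]`. Both are assumptions, not the project's code, and the numeric values in the demo are illustrative only.

```python
import random

import numpy as np


def generate_example(t, frequency, phase_shift):
    """Hypothetical stand-in: unit-amplitude sine wave sampled at times t.

    Assumes t, 1 / frequency and phase_shift share one time unit (e.g. ms).
    """
    return np.sin(2 * np.pi * frequency * (t + phase_shift))


def random_disjoint_interval(start, end, avoid_start, avoid_end):
    """Hypothetical stand-in: draw uniformly from [start, avoid_start] U [avoid_end, end].

    Each sub-interval is hit with probability proportional to its length,
    so the overall draw is uniform over the union.
    """
    left = avoid_start - start   # length of the lower sub-interval
    right = end - avoid_end      # length of the upper sub-interval
    if left < 0 or right < 0:
        raise ValueError("expected start <= avoid_start and avoid_end <= end")
    # Pick a point along the two sub-intervals laid end to end, then map it back
    u = random.uniform(0, left + right)
    return start + u if u < left else avoid_end + (u - left)


if __name__ == "__main__":
    # Illustrative values only (not taken from the project's configuration)
    sampling_period, signal_duration = 1.0, 100.0
    seq_length = int(signal_duration / sampling_period)
    t = np.linspace(0, signal_duration - sampling_period, seq_length)
    period = random_disjoint_interval(1.0, 100.0, 5.0, 6.0)
    wave = generate_example(t, 1. / period, random.uniform(0, period))
    print(period, wave.shape)
```

Sampling in proportion to sub-interval length keeps the class 1 periods uniform over the whole union; choosing one of the two sub-intervals with probability 1/2 would instead oversample the shorter one.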