test_arrayiterator.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def test_rolling_window(input_seq, batch_size, seq_len, strides):
    """Check that SequentialArrayIterator yields rolling (strided, overlapping) windows.

    For every batch produced by the iterator, the second sample is expected
    to start ``strides`` steps after the first one.  Therefore the tail of
    sample 0 (positions ``strides:seq_len``) must equal the head of sample 1
    (positions ``:seq_len - strides``).  Both the input key ``'X'`` and the
    target key ``'y'`` are checked.

    Arguments are supplied by pytest fixtures/parametrization defined
    elsewhere.  NOTE(review): the comparison indexes sample ``1`` and slices
    with ``seq_len - strides``, so it assumes ``batch_size >= 2`` and
    ``strides <= seq_len`` -- confirm the fixtures guarantee this.
    """
    # Drop the trailing part of the sequence that cannot fill a whole batch
    # of seq_len-long samples.
    n_full_batches = len(input_seq) // seq_len // batch_size
    input_seq = input_seq[:seq_len * batch_size * n_full_batches]
    # Targets are the inputs rolled left by one step along axis 0:
    # y[i] == X[i + 1], with the first element wrapped around to the end.
    data_array = {'X': input_seq,
                  'y': np.roll(input_seq, axis=0, shift=-1)}
    time_steps = seq_len
    it_array = SequentialArrayIterator(data_arrays=data_array, time_steps=time_steps,
                                       stride=strides, batch_size=batch_size, tgt_key='y',
                                       shuffle=False)
    for iter_val in it_array:
        # Sample 1 starts `strides` steps after sample 0, so the region the
        # two windows share must match element for element.
        for key in ('X', 'y'):
            assert np.array_equal(iter_val[key][0, strides:time_steps],
                                  iter_val[key][1, :time_steps - strides])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号