def test_rolling_window(input_seq, batch_size, seq_len, strides):
    """Check that SequentialArrayIterator yields correctly strided rolling windows.

    For every batch produced by the iterator, the first two samples (rows 0
    and 1) are compared: row 1 must equal row 0 shifted left by ``strides``
    time steps, i.e. the region shared by the two consecutive windows must
    match. Both the input key ``'X'`` and the target key ``'y'`` are checked.

    Args:
        input_seq: Input sequence (fixture); sliced along its first axis.
        batch_size: Samples per batch. Row 1 of each batch is indexed, so
            this presumably must be >= 2 — confirm against the parametrization.
        seq_len: Window length, passed to the iterator as ``time_steps``.
        strides: Offset between consecutive windows. NOTE(review): the
            overlap slice below assumes ``strides <= seq_len``; a larger value
            would turn ``seq_len - strides`` into a negative slice bound.
    """
    # Truncate the input so that the trailing section that does not fill a
    # whole batch is discarded.
    n_chunks = len(input_seq) // seq_len // batch_size
    input_seq = input_seq[:seq_len * batch_size * n_chunks]

    # Targets are the inputs shifted one step ahead along axis 0 (np.roll
    # wraps the first element around to the end).
    data_array = {'X': input_seq,
                  'y': np.roll(input_seq, axis=0, shift=-1)}
    it_array = SequentialArrayIterator(data_arrays=data_array, time_steps=seq_len,
                                       stride=strides, batch_size=batch_size,
                                       tgt_key='y', shuffle=False)

    # Number of time steps shared by two consecutive windows.
    overlap = seq_len - strides
    n_batches = 0
    for iter_val in it_array:
        n_batches += 1
        for key in ('X', 'y'):
            # The tail of sample 0 must equal the head of sample 1.
            assert np.array_equal(iter_val[key][0, strides:seq_len],
                                  iter_val[key][1, :overlap])

    # Without this, an iterator that yields nothing would pass vacuously.
    assert n_batches > 0, 'SequentialArrayIterator yielded no batches'
评论列表
文章目录