def test_LSTM():
    """Smoke-test the LSTM layer for both ``return_sequence`` settings.

    For each ``seq`` in ``(True, False)`` this checks that:

    * ``out_shape`` is ``None`` before ``connect_to()`` and has length 3
      (``seq=True``) or 2 (``seq=False``) afterwards;
    * ``forward`` on a random input plus a random 0/1 mask returns an array
      with that same number of dimensions;
    * ``backward`` raises ``NotImplementedError``;
    * the layer exposes exactly 12 params and 12 grads.
    """
    # Local, seeded generator: inputs are reproducible across runs and the
    # global NumPy RNG state is left untouched for other tests.
    rng = np.random.RandomState(0)

    for seq in (True, False):
        expected_ndim = 3 if seq else 2

        layer = LSTM(n_out=200, n_in=100, return_sequence=seq)
        assert layer.out_shape is None
        layer.connect_to()
        assert len(layer.out_shape) == expected_ndim

        # Named ``x`` rather than ``input`` to avoid shadowing the builtin.
        # NOTE(review): presumably x is (batch, time, n_in) and mask is
        # (batch, time) -- confirm against LSTM.forward.
        x = rng.rand(10, 50, 100)
        mask = rng.randint(0, 2, (10, 50))
        assert np.ndim(layer.forward(x, mask)) == expected_ndim

        with pytest.raises(NotImplementedError):
            layer.backward(None)

        assert len(layer.params) == 12
        assert len(layer.grads) == 12
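# NOTE(review): the original paste ended with two lines of blog-page
# navigation residue ("Comment list", "Table of contents"); they were never
# part of this test module.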