test_recurrent.py 文件源码

python
阅读 44 收藏 0 点赞 0 评论 0

项目:NumpyDL 作者: oujago 项目源码 文件源码
def test_LSTM():
    for seq in (True, False):
        layer = LSTM(n_out=200, n_in=100, return_sequence=seq)
        assert layer.out_shape is None
        layer.connect_to()
        assert len(layer.out_shape) == (3 if seq else 2)

        input = np.random.rand(10, 50, 100)
        mask = np.random.randint(0, 2, (10, 50))
        assert np.ndim(layer.forward(input, mask)) == (3 if seq else 2)

        with pytest.raises(NotImplementedError):
            layer.backward(None)

        assert len(layer.params) == 12
        assert len(layer.grads) == 12
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号