test_pooling.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:NumpyDL 作者: oujago 项目源码 文件源码
def test_MaxPooling():
    from npdl.layers import MaxPooling

    pool = MaxPooling((2, 2))

    pool.connect_to(PreLayer((10, 1, 20, 30)))
    assert pool.out_shape == (10, 1, 10, 15)

    with pytest.raises(ValueError):
        pool.forward(np.random.rand(10, 10))

    with pytest.raises(ValueError):
        pool.backward(np.random.rand(10, 20))

    assert np.ndim(pool.forward(np.random.rand(10, 20, 30))) == 3
    assert np.ndim(pool.backward(np.random.rand(10, 20, 30))) == 3

    assert np.ndim(pool.forward(np.random.rand(10, 1, 20, 30))) == 4
    assert np.ndim(pool.backward(np.random.rand(10, 1, 20, 30))) == 4
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号