test_scipy.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:autograd-forward 作者: BB-UCL 项目源码 文件源码
def test_convolve_generalization():
    ag_convolve = autograd.scipy.signal.convolve
    A_35 = R(3, 5)
    A_34 = R(3, 4)
    A_342 = R(3, 4, 2)
    A_2543 = R(2, 5, 4, 3)
    A_24232 = R(2, 4, 2, 3, 2)

    for mode in ['valid', 'full']:
        assert npo.allclose(ag_convolve(A_35,      A_34, axes=([1], [0]), mode=mode)[1, 2],
                            sp_convolve(A_35[1,:], A_34[:, 2], mode))
        assert npo.allclose(ag_convolve(A_35, A_34, axes=([],[]), dot_axes=([0], [0]), mode=mode),
                            npo.tensordot(A_35, A_34, axes=([0], [0])))
        assert npo.allclose(ag_convolve(A_35, A_342, axes=([1],[2]),
                                        dot_axes=([0], [0]), mode=mode)[2],
                            sum([sp_convolve(A_35[i, :], A_342[i, 2, :], mode)
                                 for i in range(3)]))
        assert npo.allclose(ag_convolve(A_2543, A_24232, axes=([1, 2],[2, 4]),
                                        dot_axes=([0, 3], [0, 3]), mode=mode)[2],
                            sum([sum([sp_convolve(A_2543[i, :, :, j],
                                                 A_24232[i, 2, :, j, :], mode)
                                      for i in range(2)]) for j in range(3)]))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号