def test_convolve_generalization():
    """Check autograd's generalized ``convolve`` against scipy's ``convolve``.

    Each assertion builds the expected value by slicing the inputs down to
    plain signals, convolving them with ``sp_convolve``, and, where
    ``dot_axes`` is given, summing over the contracted indices. Every check
    is repeated for both the ``'valid'`` and ``'full'`` modes.
    """
    ag_convolve = autograd.scipy.signal.convolve
    # The digits in each name spell out the shape passed to R (presumably a
    # random-array helper defined elsewhere in this file).
    A_35 = R(3, 5)
    A_34 = R(3, 4)
    A_342 = R(3, 4, 2)
    A_2543 = R(2, 5, 4, 3)
    A_24232 = R(2, 4, 2, 3, 2)

    for mode in ['valid', 'full']:
        # Convolve A_35 axis 1 with A_34 axis 0. The assertion expects the
        # remaining axes (A_35 axis 0, A_34 axis 1) to index the output, so
        # [1, 2] corresponds to the 1-D pair A_35[1, :] and A_34[:, 2].
        assert npo.allclose(
            ag_convolve(A_35, A_34, axes=([1], [0]), mode=mode)[1, 2],
            sp_convolve(A_35[1, :], A_34[:, 2], mode))

        # With no convolution axes, contracting axis 0 of both inputs is
        # expected to match a plain tensordot.
        assert npo.allclose(
            ag_convolve(A_35, A_34, axes=([], []),
                        dot_axes=([0], [0]), mode=mode),
            npo.tensordot(A_35, A_34, axes=([0], [0])))

        # Convolve A_35 axis 1 with A_342 axis 2 while contracting axis 0 of
        # both; [2] fixes A_342's leftover axis 1 at index 2, and the
        # contraction becomes an explicit sum over i.
        assert npo.allclose(
            ag_convolve(A_35, A_342, axes=([1], [2]),
                        dot_axes=([0], [0]), mode=mode)[2],
            sum(sp_convolve(A_35[i, :], A_342[i, 2, :], mode)
                for i in range(3)))

        # Two convolution axes and two contracted axes: the expected value
        # sums 2-D scipy convolutions over both contracted indices (i, j);
        # [2] fixes A_24232's only remaining axis (axis 1) at index 2.
        assert npo.allclose(
            ag_convolve(A_2543, A_24232, axes=([1, 2], [2, 4]),
                        dot_axes=([0, 3], [0, 3]), mode=mode)[2],
            sum(sp_convolve(A_2543[i, :, :, j], A_24232[i, 2, :, j, :], mode)
                for i in range(2) for j in range(3)))
评论列表
文章目录