def test_index_group_func():
import numpy as np
import cupy as cp
from chainer import cuda
input = np.random.randn(2, 3, 4, 5, 6)
I = np.random.randint(0, 4, (7, 8, 9, 10))
J = np.random.randint(0, 5, (7, 8, 9, 10))
K = np.random.randint(0, 6, (7, 8, 9, 10))
output = input[..., I, J, K].swapaxes(1, 2)
cpoutput = cp.zeros(output.shape)
cpinput = cuda.to_gpu(input)
cpI = cuda.to_gpu(I)
cpJ = cuda.to_gpu(J)
cpK = cuda.to_gpu(K)
index_group_func_kernel(cpinput, cpI, cpJ, cpK, cpoutput)
cpoutput = cuda.to_cpu(cpoutput)
error = np.abs(cpoutput - output).sum()
print(error)
assert np.isclose(error, 0.)
test_integer_indexing_cuda_kernel.py 文件源码
python
阅读 18
收藏 0
点赞 0
评论 0
评论列表
文章目录