def test_fft(backend, batch, x, y, z):
    """Check a backend's batched 3-D FFT against NumPy's reference FFT.

    A random complex64 array of shape ``(z, y, x, batch)`` in Fortran
    order is transformed over axes 0, 1 and 2; the trailing ``batch``
    axis is not transformed. Three properties are checked:

    1. forward: ``b.fftn`` matches ``np.fft.fftn``;
    2. adjoint: ``b.ifftn`` matches ``np.fft.ifftn`` multiplied by
       ``x*y*z``, i.e. the backend's inverse is expected to be
       unnormalized (NumPy's default ``ifftn`` divides by the size);
    3. round trip: ``ifftn(fftn(v)) / (x*y*z)`` recovers ``v``, which is
       equivalent to the forward transform scaled by ``1/sqrt(x*y*z)``
       being unitary.

    Parameters
    ----------
    backend : callable
        Zero-argument factory for the backend under test. The object it
        returns must provide ``copy_array``, ``fftn`` and ``ifftn``, and
        the device arrays returned by ``copy_array`` must provide
        ``to_host``.
    batch, x, y, z : int
        Extents of the test array; ``x*y*z`` is the number of points in
        each 3-D transform.
    """
    b = backend()
    shape = (z, y, x, batch)
    axes = (0, 1, 2)
    # Points per 3-D transform; also the scale factor of an unnormalized
    # inverse FFT relative to NumPy's normalized ifftn.
    n = x * y * z

    v = np.random.rand(*shape) + 1j * np.random.rand(*shape)
    # NOTE(review): Fortran order presumably matches the layout the
    # backend expects for its device arrays — confirm.
    v = np.require(v, dtype=np.dtype('complex64'), requirements='F')

    # Forward transform. Both device buffers start as copies of v; the
    # result is read back from the first argument passed to fftn (u_d).
    w_exp = np.fft.fftn(v, axes=axes)
    v_d = b.copy_array(v)
    u_d = b.copy_array(v)
    b.fftn(u_d, v_d)
    w_act = u_d.to_host()
    np.testing.assert_allclose(w_act, w_exp, atol=1e-2)

    # Adjoint of the unnormalized forward DFT is the unnormalized inverse,
    # i.e. NumPy's normalized ifftn scaled back up by n. The backend's own
    # forward output is fed in so the round-trip check below chains both.
    v_exp = np.fft.ifftn(w_act, axes=axes) * n
    v_d = b.copy_array(w_act)
    u_d = b.copy_array(w_act)
    b.ifftn(u_d, v_d)
    v_act = u_d.to_host()
    np.testing.assert_allclose(v_act, v_exp, atol=1e-2)

    # Round trip: adjoint(forward(v)) == n * v, so dividing by n must
    # recover the original input.
    np.testing.assert_allclose(v, v_act / n, atol=1e-6)
评论列表
文章目录