test_backends.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:indigo 作者: mbdriscoll 项目源码 文件源码
def test_fft(backend, batch, x, y, z):
    b = backend()
    N = (z, y, x, batch)
    v = np.random.rand(*N) + 1j*np.random.rand(*N)
    v = np.require(v, dtype=np.dtype('complex64'), requirements='F')
    ax = (0,1,2)

    # check forward
    w_exp = np.fft.fftn(v, axes=ax)
    v_d = b.copy_array(v)
    u_d = b.copy_array(v)
    b.fftn(u_d, v_d)
    w_act = u_d.to_host()
    np.testing.assert_allclose(w_act, w_exp, atol=1e-2)

    # check adjoint
    v_exp = np.fft.ifftn(w_act, axes=ax) * (x*y*z)
    v_d = b.copy_array(w_act)
    u_d = b.copy_array(w_act)
    b.ifftn(u_d, v_d)
    v_act = u_d.to_host()
    np.testing.assert_allclose(v_act, v_exp, atol=1e-2)

    # check unitary
    np.testing.assert_allclose(v, v_act / (x*y*z), atol=1e-6)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号