test_transform_filter.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:GrouPy 作者: tscohen 项目源码 文件源码
def check_transform_grad(inds, w, transformer, dtype, toll):
    from chainer import gradient_check

    inds = cuda.to_gpu(inds)

    W = Variable(w.astype(dtype))
    R = transformer(inds)

    RW = R(W)

    RW.grad = cp.random.randn(*RW.data.shape).astype(dtype)
    RW.backward(retain_grad=True)

    func = RW.creator
    fn = lambda: func.forward((W.data,))
    gW, = gradient_check.numerical_grad(fn, (W.data,), (RW.grad,))

    gan = cuda.to_cpu(gW)
    gat = cuda.to_cpu(W.grad)

    relerr = np.max(np.abs(gan - gat) / np.maximum(np.abs(gan), np.abs(gat)))

    print (dtype, toll, relerr)
    assert relerr < toll
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号