test_negative_sampling.py source code

python

Project: chainer-deconv  Author: germanRos
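# Imports assumed from the surrounding test module (Chainer v1-style layout).
import chainer
from chainer import cuda
from chainer import gradient_check
from chainer.functions.loss import negative_sampling


def check_backward(self, x_data, t_data, y_grad):
    x = chainer.Variable(x_data)
    t = chainer.Variable(t_data)
    W = self.link.W

    # Analytic gradients: run forward, seed the output gradient, backprop.
    y = self.link(x, t)
    y.grad = y_grad
    y.backward()

    # Fix the negative samples drawn in the forward pass above, so every
    # call to f() below reuses them instead of drawing new ones.
    negative_sampling.NegativeSamplingFunction.samples = y.creator.samples

    def f():
        return self.link(x, t).data,

    # Numerical gradients w.r.t. x and W (numerical_grad perturbs x.data and
    # W.data in place), weighted by the same upstream gradient y.grad.
    gx, gW = gradient_check.numerical_grad(
        f, (x.data, W.data), (y.grad,), eps=1e-2)
    del negative_sampling.NegativeSamplingFunction.samples  # clean up

    # Analytic and numerical gradients must agree within the tolerance.
    gradient_check.assert_allclose(
        cuda.to_cpu(gx), cuda.to_cpu(x.grad), atol=1.e-4)
    gradient_check.assert_allclose(
        cuda.to_cpu(gW), cuda.to_cpu(W.grad), atol=1.e-4)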
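To make the idea behind this test concrete, here is a minimal NumPy sketch of the same kind of check outside Chainer: a negative-sampling loss whose negatives are drawn once and then frozen, its analytic gradient, and a central-difference estimate of the kind `gradient_check.numerical_grad` computes. The loss form (softplus of the negated positive score plus softplus of each negative score, summed over the batch), the hypothetical helper names `neg_sampling_loss`, `neg_sampling_grads` and `numerical_grad`, and all sizes are assumptions for illustration. This is not Chainer's implementation, and it assumes a scalar loss with an upstream gradient of 1 rather than an arbitrary `y_grad`.

```python
import numpy as np


def softplus(z):
    # log(1 + exp(z)), computed without overflow.
    return np.logaddexp(0.0, z)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def _signed_scores(x, W, samples):
    # samples[i, 0] is the positive target of row i; samples[i, 1:] are negatives.
    scores = np.einsum('bd,bkd->bk', x, W[samples])  # shape (batch, k + 1)
    signs = np.ones_like(scores)
    signs[:, 0] = -1.0  # positive term uses softplus(-s), negatives softplus(+s)
    return scores, signs


def neg_sampling_loss(x, W, samples):
    # Hypothetical negative-sampling loss (assumed form), summed over the batch.
    scores, signs = _signed_scores(x, W, samples)
    return softplus(signs * scores).sum()


def neg_sampling_grads(x, W, samples):
    # Analytic gradients: d softplus(sign * s) / ds = sign * sigmoid(sign * s).
    scores, signs = _signed_scores(x, W, samples)
    g = signs * sigmoid(signs * scores)          # (batch, k + 1)
    gx = np.einsum('bk,bkd->bd', g, W[samples])  # (batch, dim)
    gW = np.zeros_like(W)
    # Scatter-add so repeated sample ids accumulate their contributions.
    np.add.at(gW, samples, g[:, :, None] * x[:, None, :])
    return gx, gW


def numerical_grad(f, arrays, eps=1e-3):
    # Central differences; each array is perturbed in place and then restored.
    grads = []
    for a in arrays:
        g = np.zeros_like(a)
        for idx in np.ndindex(a.shape):
            orig = a[idx]
            a[idx] = orig + eps
            f_plus = f()
            a[idx] = orig - eps
            f_minus = f()
            a[idx] = orig
            g[idx] = (f_plus - f_minus) / (2 * eps)
        grads.append(g)
    return grads


rng = np.random.default_rng(0)
batch, dim, vocab, k = 3, 4, 10, 2
x = rng.standard_normal((batch, dim))
W = rng.standard_normal((vocab, dim))
t = np.array([1, 5, 7])

# Draw the negatives once and freeze them, playing the role of assigning
# NegativeSamplingFunction.samples in the test above.
samples = np.empty((batch, k + 1), dtype=np.int64)
samples[:, 0] = t
samples[:, 1:] = rng.integers(0, vocab, size=(batch, k))

gx, gW = neg_sampling_grads(x, W, samples)
ngx, ngW = numerical_grad(lambda: neg_sampling_loss(x, W, samples), [x, W])
print(np.allclose(gx, ngx, atol=1e-4), np.allclose(gW, ngW, atol=1e-4))
```

Freezing `samples` matters because the loss is evaluated twice for every perturbed entry: if the negatives were redrawn on each call, the two evaluations would come from different loss functions, and their difference would no longer estimate the derivative.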