test_autograd.py 文件源码

python
阅读 38 收藏 0 点赞 0 评论 0

项目:pytorch 作者: tylergenter 项目源码 文件源码
def _test_setitem_tensor(self, size, index):
        """Check autograd through ``__setitem__`` when assigning a tensor value.

        Builds ``y = x + 2`` from a leaf ``x`` of ones with shape ``size``,
        overwrites ``y[index]`` with a leaf ``value`` filled with 7, and
        verifies that:

        * the in-place assignment changes ``y._version``;
        * after backpropagating a gradient of ones, ``x.grad`` is 1 everywhere
          except at ``index``, where it is 0 (those elements of ``y`` were
          overwritten and no longer depend on ``x``);
        * ``value.grad`` is all ones.

        A second, fixed case assigns a ``(1, 2)`` Variable into row 1 of a
        ``(10, 2)`` Variable and checks that ``x.grad`` has the same size as
        ``x``.

        Args:
            size: shape of ``x``; unpacked as ``torch.ones(*size)``, so
                presumably a tuple/list of ints.
            index: any index accepted by ``x[index]`` / ``y[index]``; if it is
                a ``Variable`` it is unwrapped to its ``.data`` before being
                used to index the plain expected-gradient tensor.
        """
        x = Variable(torch.ones(*size), requires_grad=True)
        y = x + 2
        # Record the version counter so we can confirm the in-place write
        # below is registered as a modification of y.
        y_version = y._version
        # Leaf value with the same size as the indexed region, filled with 7.
        value = Variable(torch.Tensor(x[index].size()).fill_(7), requires_grad=True)
        y[index] = value
        self.assertNotEqual(y._version, y_version)
        # Backpropagate a gradient of ones with the full shape of y.
        y.backward(torch.ones(*size))
        expected_grad_input = torch.ones(*size)
        # expected_grad_input is a plain tensor, so index it with the
        # underlying tensor rather than a Variable.
        if isinstance(index, Variable):
            index = index.data
        # Overwritten positions receive no gradient from y back into x.
        expected_grad_input[index] = 0
        self.assertEqual(x.grad.data, expected_grad_input)
        # Every element of value lands in y, so its gradient is all ones.
        self.assertEqual(value.grad.data, torch.ones(value.size()))

        # case when x is not same shape as y[1]: x is (1, 2) while y[1] is
        # (2,). The gradient flowing back to x must still have x's shape.
        x = Variable(torch.randn(1, 2), requires_grad=True)
        # y itself is created without requires_grad; the test relies on the
        # assignment of x (which requires grad) to let y.backward reach x.
        y = Variable(torch.zeros(10, 2))
        y[1] = x
        y.backward(torch.randn(10, 2))
        self.assertEqual(x.size(), x.grad.size())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号