def _test_setitem_tensor(self, size, index):
    """Check autograd through ``y[index] = value`` when ``value`` requires grad.

    Builds ``y = x + 2`` from an all-ones leaf ``x`` of shape ``size``,
    overwrites ``y[index]`` with a leaf ``value`` filled with 7, then
    backpropagates a gradient of ones through ``y`` and checks that:

    * the assignment changed ``y``'s version counter;
    * ``x.grad`` is one everywhere except the overwritten positions, where
      it is zero;
    * ``value.grad`` is all ones.

    A second check, independent of ``size`` and ``index``, writes a
    ``(1, 2)`` tensor into a ``(2,)`` row of a ``(10, 2)`` tensor and checks
    that the source's gradient keeps the source's own shape.

    Args:
        size: Dimensions of ``x``; unpacked into ``torch.ones(*size)``.
        index: Any index accepted by ``x[index]``. A ``Variable`` index is
            unwrapped to its ``.data`` before building the expected gradient.
    """
    source = Variable(torch.ones(*size), requires_grad=True)
    target = source + 2
    version_before = target._version

    slot_shape = source[index].size()
    assigned = Variable(torch.Tensor(slot_shape).fill_(7), requires_grad=True)
    target[index] = assigned
    # The assignment writes into ``target`` in place, so its version must move.
    self.assertNotEqual(target._version, version_before)

    target.backward(torch.ones(*size))

    # The expected gradient is a plain tensor; index it with the raw data
    # when the caller supplied a Variable index.
    if isinstance(index, Variable):
        plain_index = index.data
    else:
        plain_index = index
    grad_expected = torch.ones(*size)
    # Overwritten slots no longer depend on ``source``, so their grad is 0.
    grad_expected[plain_index] = 0
    self.assertEqual(source.grad.data, grad_expected)
    self.assertEqual(assigned.grad.data, torch.ones(assigned.size()))

    # Broadcast case: a (1, 2) source written into the (2,) row big[1].
    # Its gradient must come back with the source's (1, 2) shape.
    small = Variable(torch.randn(1, 2), requires_grad=True)
    big = Variable(torch.zeros(10, 2))
    big[1] = small
    big.backward(torch.randn(10, 2))
    self.assertEqual(small.size(), small.grad.size())
评论列表
文章目录