def test_sparse_backward(self):
    """Check accumulation of mixed sparse and dense gradients into a leaf.

    ``FixedGradientFunction`` is the identity in forward. In backward it
    ignores the incoming gradient and returns the tensor it was constructed
    with. Applying several such functions to the same leaf ``x``, summing
    the outputs and calling ``backward()`` makes autograd accumulate the
    injected gradients into ``x.grad``.

    The test expects the accumulated gradient to equal the plain sum of the
    injected gradients, for three orderings: sparse first, dense first and
    sparse only.
    """
    class FixedGradientFunction(Function):
        # Old-style (instance-method) autograd Function: the gradient that
        # backward returns is fixed when the instance is created.
        def __init__(self, grad):
            self.grad = grad

        def forward(self, x):
            return x

        def backward(self, grad_x):
            # Deliberately ignore grad_x so the test controls the exact
            # gradient that reaches the leaf.
            return self.grad

    size = torch.Size([6, 3, 2])
    # Sparse gradients: 2-row index matrices (one column per non-zero entry)
    # with value rows of length 2 filling the trailing dimension of `size`.
    i1 = torch.LongTensor([
        [0, 3, 4],
        [0, 2, 2],
    ])
    v1 = torch.DoubleTensor([[1, 2], [4, 5], [7, 8]])
    sparse_grad1 = torch.sparse.DoubleTensor(i1, v1, size)
    # i2 repeats every index position of i1 and adds (1, 1), so summing the
    # two sparse gradients involves overlapping entries.
    i2 = torch.LongTensor([
        [0, 1, 3, 4],
        [0, 1, 2, 2],
    ])
    v2 = torch.DoubleTensor([[1, 2], [4, 3], [4, 5], [7, 8]])
    sparse_grad2 = torch.sparse.DoubleTensor(i2, v2, size)
    dense_grad = torch.rand(size).double()
    sparse_fn1 = FixedGradientFunction(sparse_grad1)
    sparse_fn2 = FixedGradientFunction(sparse_grad2)
    dense_fn = FixedGradientFunction(dense_grad)

    def accumulated_grad(*fns):
        # Run backward through `fns[0](x) + fns[1](x) + ...` on a fresh leaf
        # and return the gradient accumulated into it.
        #
        # The leaf has the same shape and dtype as the injected gradients
        # (`size`, double). A float (5, 5) leaf would receive a (6, 3, 2)
        # double gradient that does not describe it at all.
        x = Variable(torch.randn(size).double(), requires_grad=True)
        # Call and add left to right, matching `fn_a(x) + fn_b(x) + ...`.
        out = fns[0](x)
        for fn in fns[1:]:
            out = out + fn(x)
        out.sum().backward()
        return x.grad.data

    # sparse first
    self.assertEqual(accumulated_grad(sparse_fn1, dense_fn, sparse_fn2),
                     dense_grad + sparse_grad1 + sparse_grad2)
    # dense first
    self.assertEqual(accumulated_grad(dense_fn, sparse_fn1, sparse_fn2),
                     dense_grad + sparse_grad1 + sparse_grad2)
    # sparse only
    self.assertEqual(accumulated_grad(sparse_fn1, sparse_fn2),
                     sparse_grad1 + sparse_grad2)
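# Page-navigation residue from the source web page (not code):
# "评论列表" ("Comment list"), "文章目录" ("Table of contents")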