def test_basic_op_grad_fallback(self):
    """Grad output might need to be reshaped to match the second argument.

    Exercises a range of binary/ternary ops where ``x`` (4x6) is combined
    with operands of different shapes -- (12, 2) from ``y()`` and (8, 3)
    from ``z``'s product -- and checks that ``backward()`` runs without
    error.  NOTE(review): all operands share the same number of elements
    (4*6 == 12*2 == 8*3 == 24); this appears to rely on legacy PyTorch
    "same numel" fallback semantics rather than modern broadcasting --
    confirm against the torch version this suite targets.
    """
    x = Variable(torch.randn(4, 6), requires_grad=True)
    # Small positive offsets keep b/c away from zero (``/`` and ``**``
    # below need well-behaved denominators/bases).
    b = Variable(torch.rand(12, 1) + 1e-2, requires_grad=True)
    c = Variable(torch.rand(8, 1) + 1e-2, requires_grad=True)

    def y():
        # .mm() depends on the grad_output being of correct size
        return b.mm(Variable(torch.rand(1, 2) + 1e-2))

    def z():
        return c.mm(Variable(torch.rand(1, 3) + 1e-2))

    # suppress broadcastable warning
    with warnings.catch_warnings(record=True):
        (x + y()).sum().backward()
        (x - y()).sum().backward()
        (x * y()).sum().backward()
        (x / y()).sum().backward()
        (x.dist(y())).sum().backward()
        (x.lerp(y(), 0.5)).sum().backward()
        (x.max(y())).sum().backward()
        (x.min(y())).sum().backward()
        (x.masked_fill(y() < 0, 0.5)).sum().backward()
        (x.masked_scatter(Variable(y().data < 0.25), z())).sum().backward()
        (x.masked_select(Variable(y().data < 0.25))).sum().backward()
        (x.addcmul(1, y(), z())).sum().backward()
        (x.addcdiv(1, y(), z())).sum().backward()
        (x.abs() ** y()).sum().backward()