def test_L1Penalty(self):
    """Check the legacy ``nn.L1Penalty`` loss, gradient and repr/str.

    The module is constructed with two positional ``False`` flags, and the
    test expects:

    * ``module.loss == weight * sum(|x|)`` after ``forward``;
    * ``backward`` to return ``weight * sign(x)`` (elementwise, 0 at 0),
      i.e. the all-ones gradOutput passed in is *not* added to the result.
    """
    # Use a non-unit weight so the scaling is really exercised: with
    # weight == 1 every ``* weight`` below is a no-op, and a module that
    # ignored its weight would still pass.
    weight = 2.5
    # NOTE(review): the two positional False flags are presumably
    # sizeAverage and provideOutput of the legacy API — confirm.
    module = nn.L1Penalty(weight, False, False)
    # torch.rand samples [0, 1); shifting by -0.5 gives [-0.5, 0.5) so
    # positive and negative entries are both likely to appear.
    x = torch.rand(2, 10).add_(-0.5)
    # Force an exact zero so the sign-0 (zero subgradient) case is covered.
    x[0][0] = 0
    module.forward(x)
    grad = module.backward(x, torch.ones(x.size()))
    self.assertEqual(x.abs().sum() * weight, module.loss)
    # Expected gradient: +1 where x > 0, -1 where x < 0, 0 at zero, scaled
    # by weight; built from two comparisons in grad's tensor type.
    positive = x.gt(0).type_as(grad)
    negative = x.lt(0).type_as(grad)
    true_grad = (positive - negative).mul_(weight)
    self.assertEqual(true_grad, grad)
    # Check that these don't raise errors
    module.__repr__()
    str(module)
评论列表
文章目录