test_legacy_nn.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:pytorch 作者: tylergenter 项目源码 文件源码
def test_L1Penalty(self):
        weight = 1
        m = nn.L1Penalty(weight, False, False)

        input = torch.rand(2, 10).add_(-0.5)
        input[0][0] = 0

        m.forward(input)
        grad = m.backward(input, torch.ones(input.size()))

        self.assertEqual(input.abs().sum() * weight, m.loss)

        true_grad = (input.gt(0).type_as(grad) +
                     input.lt(0).type_as(grad).mul_(-1)).mul_(weight)
        self.assertEqual(true_grad, grad)

        # Check that these don't raise errors
        m.__repr__()
        str(m)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号