def check_lasso(self):
    # self.target is a link holding a single parameter; cuda, optimizers,
    # optimizer, and gradient_check are the usual chainer modules.
    w = self.target.param.data
    g = self.target.param.grad
    xp = cuda.get_array_module(w)
    decay = 0.2
    # The Lasso hook adds decay * sign(w) to the gradient, so SGD with
    # lr=1 should update the parameter to w - g - decay * sign(w).
    expect = w - g - decay * xp.sign(w)

    opt = optimizers.SGD(lr=1)
    opt.setup(self.target)
    opt.add_hook(optimizer.Lasso(decay))
    opt.update()

    gradient_check.assert_allclose(expect, w)
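What the check asserts, in plain NumPy terms: the Lasso hook adds decay * sign(w) to the gradient, and SGD with lr=1 then subtracts that modified gradient, so the updated parameter should equal w - g - decay * sign(w). Below is a minimal sketch of that arithmetic; the array values are arbitrary illustrations, not the actual test fixture.

import numpy as np

# Arbitrary example values; any float arrays of matching shape work here.
w = np.arange(6, dtype=np.float32).reshape(2, 3)
g = np.arange(3, -3, -1, dtype=np.float32).reshape(2, 3)
decay = 0.2

g_penalized = g + decay * np.sign(w)   # gradient after the Lasso hook runs
w_updated = w - 1.0 * g_penalized      # plain SGD step with lr=1

expect = w - g - decay * np.sign(w)    # what the test computes up front
assert np.allclose(w_updated, expect)

This is the same identity that gradient_check.assert_allclose(expect, w) verifies after opt.update() has modified w in place.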