def check_weight_decay(self):
    """Check one SGD step with a ``WeightDecay`` hook attached.

    The expected value ``w - g - decay * w`` is computed from the target's
    parameter data ``w`` and gradient ``g`` before the optimizer runs.
    After a single ``update()`` with ``lr=1`` the previously fetched
    ``w`` reference is compared against that expectation.
    """
    weights = self.target.param.data
    grads = self.target.param.grad

    decay_rate = 0.2
    # Build the expectation up front; the arithmetic yields a new object,
    # so it is unaffected by whatever the optimizer does afterwards.
    expected = weights - grads - decay_rate * weights

    sgd = optimizers.SGD(lr=1)
    sgd.setup(self.target)
    sgd.add_hook(optimizer.WeightDecay(decay_rate))
    sgd.update()

    # NOTE(review): comparing against the reference taken before the
    # update presumes the optimizer modifies the parameter data in place.
    gradient_check.assert_allclose(expected, weights)
评论列表
文章目录