def check_gradient_scaling(self):
    """Check one SGD step taken with a GradientScaling hook attached.

    The test expects that, with ``lr=1`` and a ``GradientScaling(scale)``
    hook, a single ``update()`` leaves the parameter equal to
    ``weight - grad * scale``, where ``weight`` and ``grad`` are the
    values held by ``self.target.param`` before the step.
    """
    param = self.target.param
    weight = param.array
    grad = param.grad
    scale = 0.2

    # The expectation must be built before the update runs. The assertion
    # at the end reads ``weight`` again afterwards, so it relies on the
    # update modifying the parameter array in place.
    # NOTE(review): that aliasing is implicit here; confirm it holds for
    # the optimizer under test.
    expected = weight - grad * scale

    sgd = optimizers.SGD(lr=1)
    sgd.setup(self.target)
    sgd.add_hook(GradientScaling(scale))
    sgd.update()

    testing.assert_allclose(expected, weight)
# (scraped web-page residue) 评论列表 — "Comment list"
# (scraped web-page residue) 文章目录 — "Table of contents"