def __call__(self, opt):
    """Add the L1 (lasso) penalty gradient to each parameter's gradient.

    For every parameter ``param`` yielded by ``opt.target.params()``, this
    performs the in-place update
    ``param.grad += self.rate * sign(param.data)``. That term is the
    (sub)gradient of ``self.rate * sum(|param.data|)``.

    Args:
        opt: Optimizer whose ``target.params()`` yields the parameters to
            regularize. This is presumably a Chainer ``Optimizer`` wrapping
            a ``Link``; confirm against callers.

    Parameters whose ``data`` or ``grad`` is ``None`` are skipped. This
    happens, for example, when a parameter is not yet initialized or
    received no gradient.
    """
    if cuda.available:
        # Fused in-place GPU update: g += decay * s.
        kernel = cuda.elementwise(
            'T s, T decay', 'T g', 'g += decay * s', 'lasso')
    rate = self.rate
    for param in opt.target.params():
        p, g = param.data, param.grad
        # Nothing to regularize: the update cannot be applied when either
        # the parameter array or its gradient is missing.
        if p is None or g is None:
            continue
        xp = cuda.get_array_module(p)
        # Enter p's device *before* computing sign(p), so the temporary
        # array lives on the same device as p and g.
        with cuda.get_device(p) as dev:
            sign = xp.sign(p)
            if int(dev) == -1:
                # Host path (device id -1): plain in-place array update.
                g += rate * sign
            else:
                # Device path: 'kernel' is always bound here, because a
                # non-host device implies cuda.available was true above.
                kernel(sign, rate, g)
评论列表
文章目录