def __call__(self, opt):
    """Add ``self.rate * data`` to the gradient of each selected parameter.

    Iterates over ``opt.target.namedparams()``. Every parameter whose name
    is exactly ``'b'`` or ends with ``'/b'`` is skipped (presumably bias
    terms -- confirm against the model's naming). For all others the
    gradient is updated in place: ``grad += rate * data``.

    Args:
        opt: Object whose ``target.namedparams()`` yields ``(name, param)``
            pairs, where each ``param`` provides ``data`` and ``grad``.
            Presumably a Chainer optimizer -- confirm against the caller.
    """
    # The elementwise kernel is only created when CUDA is usable; it is
    # referenced solely on the non-(-1) device branch below.
    if cuda.available:
        gpu_decay = cuda.elementwise(
            'T p, T decay', 'T g', 'g += decay * p', 'weight_decay')

    coeff = self.rate
    for path, variable in opt.target.namedparams():
        skipped = path == 'b' or path.endswith('/b')
        if not skipped:
            weights, grads = variable.data, variable.grad
            with cuda.get_device(weights) as device:
                if int(device) != -1:
                    gpu_decay(weights, coeff, grads)
                else:
                    # Device id -1: update the array directly
                    # (presumably the CPU path -- TODO confirm).
                    grads += coeff * weights
# 评论列表 (comment list)
# 文章目录 (table of contents)