def __call__(self, opt):
    """Clamp every parameter of ``opt.target`` to ``[self.low, self.high]``.

    Each parameter's ``data`` array is modified in place; nothing is
    returned.

    Args:
        opt: Object whose ``target.params()`` yields the parameters to clip.
    """
    if cuda.available:
        # Elementwise kernel that overwrites ``p`` with its clamped value.
        # It is only defined when CUDA is available; the non-CPU branch
        # below is the only place that uses it.
        kernel = cuda.elementwise(
            'T low, T high',
            'T p',
            'p = (p < low) ? low : (p > high) ? high : p',
            'weight_clip')
    for param in opt.target.params():
        p = param.data
        with cuda.get_device(p) as dev:
            if int(dev) == -1:
                # Device id -1 is handled with NumPy. numpy.clip returns a
                # new array unless ``out`` is given, so write the result
                # back into ``p`` to match the in-place GPU kernel.
                numpy.clip(p, self.low, self.high, out=p)
            else:
                kernel(self.low, self.high, p)
# Comment list
# Article contents