def gradient_clipping(gradients, max_norm=5.0):
    """Rescale *gradients* so their global L2 norm does not exceed *max_norm*.

    Parameters
    ----------
    gradients : list of symbolic tensors
        Gradients w.r.t. the model parameters (any shapes).
    max_norm : float, optional
        Maximum allowed global norm; gradients whose combined norm is below
        this threshold are returned unscaled.

    Returns
    -------
    list of symbolic tensors
        The gradients, each multiplied by the same scalar so that the global
        norm of the result is min(global_norm, max_norm).
    """
    # Global L2 norm across ALL gradient tensors (standard "clip by global norm").
    global_grad_norm = tensor.sqrt(sum(map(lambda x: tensor.sqr(x).sum(), gradients)))
    # Equivalent to switch(norm < max_norm, 1.0, max_norm / norm), but avoids
    # building a max_norm / 0 division into the graph when all gradients are
    # zero: tensor.switch evaluates both branches symbolically, so the
    # original form produces inf/NaN in that branch even though it is unused.
    multiplier = max_norm / tensor.maximum(global_grad_norm, max_norm)
    return [g * multiplier for g in gradients]