def clip_gradient(model, clip):
    """Compute a gradient-clipping coefficient from the global gradient norm.

    The global norm is the L2 norm over all parameter gradients, i.e.
    ``sqrt(sum(norm(p.grad) ** 2 for p in params))``. The returned
    coefficient is ``clip / (global_norm + 1e-6)``, capped at 1.0.
    Multiplying every gradient by it rescales the global norm to at most
    ``clip``. Gradients are only read, never modified.

    Args:
        model: Object whose ``parameters()`` yields parameters exposing
            ``.grad`` (``None`` or an object with ``.data.norm()``).
        clip: Maximum allowed global gradient norm.

    Returns:
        float: Scaling coefficient; 1.0 when no clipping is needed.
    """
    import math

    squared_total = 0.0
    for param in model.parameters():
        # Parameters that received no gradient (frozen / unused) have
        # grad == None; they contribute nothing to the global norm.
        if param.grad is None:
            continue
        # float() turns a 0-dim tensor (or plain number) into a Python float
        # so the accumulation and math.sqrt operate on scalars.
        squared_total += float(param.grad.data.norm()) ** 2
    total_norm = math.sqrt(squared_total)
    # The epsilon avoids division by zero when all gradients are zero.
    return min(1.0, clip / (total_norm + 1e-6))
评论列表
文章目录