import tensorflow as tf


def clip_and_debug_gradients(gradients, opts):
    # extract just the gradients temporarily for global clipping, then rezip
    if opts.gradient_clip is not None:
        just_gradients, variables = zip(*gradients)
        just_gradients, _ = tf.clip_by_global_norm(just_gradients,
                                                   opts.gradient_clip)
        gradients = list(zip(just_gradients, variables))
    # verbose debugging: print each gradient's l2 norm at run time
    # (l2_norm is a helper defined elsewhere, e.g. tf.sqrt(tf.reduce_sum(tf.square(g))))
    if opts.print_gradients:
        gradients = list(gradients)
        for i, (gradient, variable) in enumerate(gradients):
            if gradient is not None:
                gradients[i] = (tf.Print(gradient, [l2_norm(gradient)],
                                         "gradient %s l2_norm " % variable.name),
                                variable)
    # done
    return gradients
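
# For context, a minimal usage sketch (not from the original snippet): wiring the
# helper into a TF1-style training step. The `loss` tensor and the `opts` fields
# (gradient_clip, print_gradients) are assumptions for illustration.
optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
gradients = optimizer.compute_gradients(loss)           # list of (gradient, variable) pairs
gradients = clip_and_debug_gradients(gradients, opts)   # global clip + optional debug prints
train_op = optimizer.apply_gradients(gradients)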