import tensorflow as tf

# String keys mapped to TF clipping ops. The original module defines this
# mapping elsewhere; the entries below are an assumed, representative set.
_str2clipper = {
    'value': tf.clip_by_value,
    'norm': tf.clip_by_norm,
}


def get_gradient_clipper(clipper, *args, **kwargs):
"""
Simple helper to get Gradient Clipper
E.g: clipper = get_gradient_clipper('value', value_min, value_max, name='ValueClip')
:param clipper: a string denoting TF Gradient Clipper (e.g. "global_norm", denote tf.clip_by_global_norm)
or a function of type f(tensor) -> clipped_tensor
:param args: used to create the clipper
:param kwargs: used to create the clipper
:return: a function (tensor) -> (clipped tensor)
"""
    if callable(clipper):
        return clipper
    # Special-case 'global_norm': tf.clip_by_global_norm operates on the whole list
    # and returns two values, the clipped list and the global norm (a scalar tensor),
    # so only the first element is kept.
    if clipper == 'global_norm':
        return lambda t_list: tf.clip_by_global_norm(t_list, *args, **kwargs)[0]
    if clipper in _str2clipper:
        clipper = _str2clipper[clipper]
    else:
        raise ValueError('clipper should be a callable function or a key in _str2clipper!')
    # All other clippers operate on a single tensor, so apply them element-wise.
    return lambda t_list: [clipper(t, *args, **kwargs) for t in t_list]
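

# Minimal usage sketch (illustration only, not part of the original module).
# Assumes a TF1-style graph/optimizer API (tf.train.GradientDescentOptimizer,
# compute_gradients/apply_gradients); the variable, loss, and learning rate
# below are hypothetical placeholders.
if __name__ == '__main__':
    w = tf.Variable([1.0, -2.0, 3.0])
    loss = tf.reduce_sum(tf.square(w))      # simple quadratic loss
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)

    # Build a clipper that rescales the gradient list to a global norm of at most 5.0.
    clip_fn = get_gradient_clipper('global_norm', 5.0)

    grads, variables = zip(*optimizer.compute_gradients(loss))
    clipped_grads = clip_fn(list(grads))    # list of tensors -> list of clipped tensors
    train_op = optimizer.apply_gradients(zip(clipped_grads, variables))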