trainer.py source code

python

Project: RNNVis  Author: myaooo
import tensorflow as tf

# NOTE: `_str2clipper` is a module-level dict defined elsewhere in trainer.py
# (not shown in this snippet); it maps clipper names to clipping functions.


def get_gradient_clipper(clipper, *args, **kwargs):
    """
    Simple helper that builds a gradient clipper.
    E.g.: clipper = get_gradient_clipper('value', value_min, value_max, name='ValueClip')
    :param clipper: a string naming a TF gradient clipper (e.g. "global_norm" selects
        tf.clip_by_global_norm), or a callable clipper, which is returned as-is
    :param args: positional arguments forwarded to the underlying clipping function
    :param kwargs: keyword arguments forwarded to the underlying clipping function
    :return: a function (list of tensors) -> (list of clipped tensors)
    :raises ValueError: if clipper is neither callable nor a key in _str2clipper
    """
    if callable(clipper):
        return clipper
    # tf.clip_by_global_norm operates on the whole list and returns a pair
    # (clipped_list, global_norm); keep only the clipped list.
    if clipper == 'global_norm':
        return lambda t_list: tf.clip_by_global_norm(t_list, *args, **kwargs)[0]
    if clipper in _str2clipper:
        clipper = _str2clipper[clipper]
    else:
        raise ValueError('clipper should be a callable or a key in _str2clipper!')
    # The remaining clippers act on one tensor at a time, so apply them per tensor.
    return lambda t_list: [clipper(t, *args, **kwargs) for t in t_list]
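The dispatch logic above can be tried without TensorFlow. The sketch below is an illustration only, not the RNNVis code: `clip_by_value` and `clip_by_global_norm` are hypothetical plain-Python stand-ins that operate on lists of floats, and `_STR2CLIPPER` is an assumed registry, since the real `_str2clipper` is not shown in this snippet. The global-norm stand-in uses the scaling `clip_norm / max(global_norm, clip_norm)` and returns a pair, which is why the helper keeps only element `[0]`.

```python
import math


def clip_by_value(values, value_min, value_max):
    """Hypothetical stand-in: clamp every float in one 'tensor' to [value_min, value_max]."""
    return [min(max(v, value_min), value_max) for v in values]


def clip_by_global_norm(t_list, clip_norm):
    """Hypothetical stand-in: rescale all 'tensors' jointly; returns (clipped_list, global_norm)."""
    global_norm = math.sqrt(sum(v * v for t in t_list for v in t))
    scale = clip_norm / max(global_norm, clip_norm)
    return [[v * scale for v in t] for t in t_list], global_norm


# Assumed registry of per-tensor clippers (the original's _str2clipper is not shown).
_STR2CLIPPER = {'value': clip_by_value}


def get_gradient_clipper(clipper, *args, **kwargs):
    if callable(clipper):
        return clipper
    if clipper == 'global_norm':
        # Works on the whole list at once; drop the returned norm.
        return lambda t_list: clip_by_global_norm(t_list, *args, **kwargs)[0]
    if clipper not in _STR2CLIPPER:
        raise ValueError('clipper should be a callable or a key in _STR2CLIPPER!')
    fn = _STR2CLIPPER[clipper]
    # Per-tensor clippers are mapped over the list.
    return lambda t_list: [fn(t, *args, **kwargs) for t in t_list]


grads = [[3.0, -4.0], [0.5]]  # two toy "gradient tensors"

value_clipper = get_gradient_clipper('value', -1.0, 1.0)
clipped_by_value = value_clipper(grads)

norm_clipper = get_gradient_clipper('global_norm', 1.0)
clipped_by_norm = norm_clipper(grads)
```

Both returned clippers share one calling convention, a list of tensors in and a list of clipped tensors out, so the training code does not need to know which strategy was chosen.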