optimizer.py source code

python

Project: tefla  Author: openAGI
import six
import tensorflow as tf


def _is_all_zeros(grad):
    # Missing helper referenced below: True iff every entry of the gradient is zero.
    return tf.equal(tf.count_nonzero(grad), 0)


def clip_gradients_by_global_norm(gradients_variables, clip_norm=20.):
    """Clips gradients of a multitask loss by their global norm.

    Ignores all-zero tensors when computing the global norm.

    Args:
      gradients_variables: a list of pairs (gradient, variable).
      clip_norm: a float Tensor, the global norm to clip on. Default is 20.0.

    Returns:
      list: A list of (gradient, variable) pairs of the same type as
        gradients_variables, with gradients clipped by the global norm.
      fixed_global_norm: A 0-D (scalar) Tensor representing the global norm.
    """
    gradients, variables = six.moves.zip(*gradients_variables)

    def _replace_nonexisting_grad(grad):
        if grad is None:
            # tf.global_norm and tf.clip_by_global_norm both skip None entries.
            return grad
        all_zeros = _is_all_zeros(grad)
        # Collapse an all-zero gradient to a scalar zero so it contributes
        # nothing to the global norm and costs nothing to evaluate.
        return tf.cond(
            all_zeros,
            lambda: tf.zeros([], dtype=tf.as_dtype(grad.dtype)),
            lambda: grad)

    nonzero_gradients = [_replace_nonexisting_grad(g) for g in gradients]
    fixed_global_norm = tf.global_norm(nonzero_gradients)
    gradients, _ = tf.clip_by_global_norm(
        gradients, clip_norm, use_norm=fixed_global_norm)
    return list(six.moves.zip(gradients, variables)), fixed_global_norm
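Under the hood, tf.clip_by_global_norm multiplies every gradient by the same factor, clip_norm / max(use_norm, clip_norm), so the tensors are rescaled jointly rather than clipped per-tensor. A minimal NumPy sketch of that arithmetic (clip_by_global_norm_np is a hypothetical helper for illustration, not part of tefla or TensorFlow):

```python
import numpy as np

def clip_by_global_norm_np(grads, clip_norm):
    """Rescale a list of gradient arrays so their joint L2 norm is <= clip_norm."""
    # Global norm: sqrt of the sum of squared entries across all tensors.
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    # One shared scale factor; when the global norm is already within the
    # limit the factor is 1 and the gradients pass through unchanged.
    scale = clip_norm / max(global_norm, clip_norm)
    return [g * scale for g in grads], global_norm
```

Scaling by a shared factor preserves the relative direction of the combined gradient, which is why global-norm clipping is preferred over per-tensor clipping for multitask losses.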