base.py source code

python

Project: tefla  Author: openAGI
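import tensorflow as tf


# Method excerpted from a class in tefla's base.py (hence the `self` argument).
def _clip_grad_norms(self, gradients_to_variables, max_norm=5):
    """Clips each gradient so that its L2 norm does not exceed `max_norm`.

    Args:
        gradients_to_variables: A list of (gradient, variable) pairs (tuples).
        max_norm: the maximum L2 norm allowed for each individual gradient.

    Returns:
        A list of (clipped gradient, variable) pairs. `None` gradients are
        passed through unchanged.
    """
    grads_and_vars = []
    for grad, var in gradients_to_variables:
        if grad is not None:
            if isinstance(grad, tf.IndexedSlices):
                # Sparse gradient (e.g. from an embedding lookup): clip only the
                # stored values and keep indices and dense_shape as they are.
                tmp = tf.clip_by_norm(grad.values, max_norm)
                grad = tf.IndexedSlices(tmp, grad.indices, grad.dense_shape)
            else:
                # Dense gradient: rescale the whole tensor if its norm is too large.
                grad = tf.clip_by_norm(grad, max_norm)
        grads_and_vars.append((grad, var))
    return grads_and_vars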
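The function relies on the per-tensor rule that `tf.clip_by_norm` applies: a tensor t is replaced by t * max_norm / max(||t||_2, max_norm), so a tensor already within the limit is left unchanged and a larger one is scaled down to norm exactly `max_norm`. As a framework-free sketch of the same logic (pure Python, with nested lists standing in for tensors), it might look like the following. `SparseGrad` is a hypothetical stand-in for `tf.IndexedSlices`, and `clip_by_norm` / `clip_grad_norms` here are illustrative helpers, not TensorFlow APIs.

```python
import math


class SparseGrad:
    """Hypothetical stand-in for tf.IndexedSlices: stores only some rows."""

    def __init__(self, values, indices, dense_shape):
        self.values = values            # list of rows, one per index
        self.indices = indices          # row index of each entry in values
        self.dense_shape = dense_shape  # shape of the equivalent dense tensor


def _flatten(t):
    """Yield every scalar in a (possibly nested) list."""
    if isinstance(t, (int, float)):
        yield float(t)
    else:
        for item in t:
            yield from _flatten(item)


def _scale(t, factor):
    """Multiply every scalar in a nested list by factor, keeping the nesting."""
    if isinstance(t, (int, float)):
        return t * factor
    return [_scale(item, factor) for item in t]


def clip_by_norm(t, max_norm):
    """Rescale t so that its L2 norm is at most max_norm (assumes max_norm > 0)."""
    norm = math.sqrt(sum(x * x for x in _flatten(t)))
    if norm <= max_norm:
        return t  # already within the limit (also covers an all-zero tensor)
    return _scale(t, max_norm / norm)


def clip_grad_norms(grads_and_vars, max_norm=5.0):
    """Clip each gradient independently; None gradients pass through."""
    clipped = []
    for grad, var in grads_and_vars:
        if grad is not None:
            if isinstance(grad, SparseGrad):
                # Clip only the stored rows; indices and shape stay the same.
                grad = SparseGrad(clip_by_norm(grad.values, max_norm),
                                  grad.indices, grad.dense_shape)
            else:
                grad = clip_by_norm(grad, max_norm)
        clipped.append((grad, var))
    return clipped


# Usage: a dense gradient at the limit, a missing gradient, and a sparse one.
pairs = [
    ([3.0, 4.0], "w"),                                # norm 5: unchanged
    (None, "b"),                                      # no gradient
    (SparseGrad([[6.0, 8.0]], [2], (4, 2)), "emb"),  # norm 10: halved
]
out = clip_grad_norms(pairs, max_norm=5.0)
print(out[0][0], out[1][0], out[2][0].values)  # [3.0, 4.0] None [[3.0, 4.0]]
```

Because every gradient is clipped on its own, the direction of the combined update across all variables can change. TensorFlow's `tf.clip_by_global_norm` takes the other approach: it scales all gradients by one shared factor computed from their joint norm, which preserves that direction.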