base.py source code

python

Project: tefla    Author: openAGI
def _clip_grad_global_norms(self, tvars, loss, opt, global_norm=8, gate_gradients=1,
                            gradient_noise_scale=None, GATE_GRAPH=2, grad_loss=None,
                            agre_method=None, col_grad_ops=False):
    """Clips the gradients by the given global norm.

    Args:
        tvars: trainable variables used for gradient updates
        loss: total loss of the network
        opt: optimizer (not referenced in this function; the caller applies the returned pairs)
        global_norm: the maximum global norm for the clipped gradients
        gate_gradients: gradient gating level; 1 gates individual ops, GATE_GRAPH groups
            all gradients with tf.tuple
        gradient_noise_scale: if set, adds scaled noise to the gradients
        GATE_GRAPH: gating level at which all gradients are grouped together
        grad_loss: optional tensors passed as grad_ys to tf.gradients
        agre_method: aggregation method passed to tf.gradients
        col_grad_ops: whether to colocate gradient ops with the corresponding forward ops

    Returns:
        A list of clipped (gradient, variable) pairs.
    """
    var_refs = [v.read_value() for v in tvars]
    grads = tf.gradients(
        loss, var_refs, grad_ys=grad_loss,
        gate_gradients=(gate_gradients == 1),
        aggregation_method=agre_method,
        colocate_gradients_with_ops=col_grad_ops)
    if gradient_noise_scale is not None:
        grads = self._add_scaled_noise_to_gradients(
            list(zip(grads, tvars)), gradient_noise_scale=gradient_noise_scale)
    if gate_gradients == GATE_GRAPH:
        grads = tf.tuple(grads)
    # Rescale all gradients jointly so their global L2 norm is at most global_norm.
    grads, _ = tf.clip_by_global_norm(grads, global_norm)
    grads_and_vars = list(zip(grads, tvars))
    return grads_and_vars
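
For context, below is a minimal, self-contained sketch of the same clip-by-global-norm pattern, assuming the TensorFlow 1.x graph-mode API. The toy model, the placeholder names (x, y, w, b), and the MomentumOptimizer are illustrative assumptions, not part of tefla; only the tf.gradients / tf.clip_by_global_norm / apply_gradients sequence mirrors the method above.

import tensorflow as tf

# Toy linear model (illustrative only).
x = tf.placeholder(tf.float32, shape=[None, 4])
y = tf.placeholder(tf.float32, shape=[None, 1])
w = tf.Variable(tf.random_normal([4, 1]), name='w')
b = tf.Variable(tf.zeros([1]), name='b')
loss = tf.reduce_mean(tf.square(tf.matmul(x, w) + b - y))

opt = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9)
tvars = tf.trainable_variables()
grads = tf.gradients(loss, tvars)
# Rescale all gradients together so their joint L2 norm is at most 8,
# preserving their relative directions (unlike element-wise clipping).
grads, _ = tf.clip_by_global_norm(grads, clip_norm=8.0)
train_op = opt.apply_gradients(list(zip(grads, tvars)))

The resulting train_op would then be run inside a tf.Session; because global-norm clipping scales every gradient by the same factor, the relative magnitudes between layers are preserved, which is the main reason it is preferred over per-tensor clipping here.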