training.py source code

python

Project: tefla  Author: openAGI
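import tensorflow as tf


def clip_grad_global_norms(tvars, loss, opt, global_norm=1, gate_gradients=1,
                           gradient_noise_scale=4.0, GATE_GRAPH=2, grad_loss=None,
                           agre_method=None, col_grad_ops=False):
    """Clips the gradients by their global norm.

    Args:
        tvars: trainable variables used for gradient updates
        loss: total loss of the network
        opt: optimizer (not used in this function)
        global_norm: the maximum global norm
        gate_gradients: gating mode; 1 enables gating in `tf.gradients`,
            `GATE_GRAPH` groups all gradients with `tf.tuple`
        gradient_noise_scale: noise is added to the gradients when this is > 1
        GATE_GRAPH: value of `gate_gradients` that selects graph-level gating
        grad_loss: optional initial gradients, passed as `grad_ys`
        agre_method: passed as `aggregation_method` to `tf.gradients`
        col_grad_ops: passed as `colocate_gradients_with_ops` to `tf.gradients`

    Returns:
        A list of clipped (gradient, variable) pairs.
    """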
    var_refs = [v.ref() for v in tvars]
    grads = tf.gradients(loss, var_refs, grad_ys=grad_loss, gate_gradients=(
        gate_gradients == 1), aggregation_method=agre_method, colocate_gradients_with_ops=col_grad_ops)
    if gradient_noise_scale > 1:
        grads = add_scaled_noise_to_gradients(
            list(zip(grads, tvars)), gradient_noise_scale=gradient_noise_scale)
    if gate_gradients == GATE_GRAPH:
        grads = tf.tuple(grads)
    grads, _ = tf.clip_by_global_norm(grads, global_norm)
    grads_and_vars = list(zip(grads, tvars))
    return grads_and_vars
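
For intuition, the `tf.clip_by_global_norm` call in the function rescales all gradients by one common factor so that their joint L2 norm does not exceed `global_norm`. The following is a minimal pure-Python sketch of that rule, not the TensorFlow implementation. It assumes each gradient is a flat list of floats, and the helper name `clip_by_global_norm_sketch` is hypothetical (it exists in neither tefla nor TensorFlow).

```python
import math


def clip_by_global_norm_sketch(grads, clip_norm):
    """Illustrative global-norm clipping on plain Python lists.

    grads: list of gradients, each a flat list of floats (assumed layout).
    clip_norm: maximum allowed global norm.
    Returns (clipped_grads, global_norm).
    """
    # Global norm: L2 norm over every entry of every gradient taken together.
    global_norm = math.sqrt(sum(x * x for g in grads for x in g))
    # The factor is exactly 1 when the norm is already within the limit,
    # so small gradients pass through unchanged.
    scale = clip_norm / max(global_norm, clip_norm)
    clipped = [[x * scale for x in g] for g in grads]
    return clipped, global_norm


grads = [[3.0], [4.0]]
clipped, norm = clip_by_global_norm_sketch(grads, clip_norm=1.0)
print(norm)  # 5.0
print(clipped)
```

Because every gradient is multiplied by the same factor, the direction of the overall update is preserved and only its magnitude is limited. Clipping each gradient separately would not have this property.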