bingrad_common.py 文件源码

python
阅读 39 收藏 0 点赞 0 评论 0

项目:terngrad 作者: wenwei202 项目源码 文件源码
def gradient_binarizing_scalers(grads_and_vars, clip_factor):
    """ Get the scalers."""
    gradients, variables = zip(*grads_and_vars)
    scalers = []
    for gradient in gradients:
        if gradient is None:
            scalers.append(None)
            continue

        if(clip_factor > 1.0e-5):
            mean_gradient = tf.reduce_mean(gradient)
            stddev_gradient = tf.sqrt(tf.reduce_mean(tf.square(gradient - mean_gradient)))
            scalers.append(clip_factor * stddev_gradient)
        else:
            scalers.append(tf.reduce_max(tf.abs(gradient)))

    return list(zip(scalers, variables))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号