bingrad_common.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:terngrad 作者: wenwei202 项目源码 文件源码
def clip_gradients_by_stddev(grads_and_vars, clip_factor = 2.5):
    """ Clip gradients to [-clip_factor*stddev, clip_factor*stddev]."""
    gradients, variables = zip(*grads_and_vars)
    clipped_gradients = []
    for gradient in gradients:
        if gradient is None:
            clipped_gradients.append(None)
            continue

        mean_gradient = tf.reduce_mean(gradient)
        stddev_gradient = tf.sqrt(tf.reduce_mean(tf.square(gradient - mean_gradient)))
        #clipped_gradient = tf.clip_by_value(gradient, -clip_factor * stddev_gradient, clip_factor * stddev_gradient)
        clipped_gradient = tf.cond(tf.size(gradient) < FLAGS.size_to_binarize,
                               lambda: gradient,
                               lambda: tf.clip_by_value(gradient, -clip_factor * stddev_gradient, clip_factor * stddev_gradient))

        clipped_gradients.append(clipped_gradient)
    return list(zip(clipped_gradients, variables))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号