nn.py source code

Project: icml17_knn    Author: taolei87
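import tensorflow as tf  # TF 1.x graph-mode API (tf.variable_scope, tf.get_variable, tf.assign)


def batch_normalization_with_mask(x, mask, scope, decay=0.999, eps=1e-6, training=True):
    # Batch normalization whose statistics ignore padded positions. `mask` weights each
    # position (1 = real, 0 = padding) and must be broadcastable to `x`.
    ndim = len(x.get_shape().as_list())
    fdim = x.get_shape().as_list()[-1]  # normalize over the last (feature) axis
    with tf.variable_scope(scope):
        # Learnable affine parameters.
        gamma = tf.get_variable("scale", [fdim], tf.float32, tf.constant_initializer(1.0))
        beta = tf.get_variable("offset", [fdim], tf.float32, tf.constant_initializer(0.0))
        # Running (population) statistics for inference; not updated by gradients.
        mean = tf.get_variable("mean", [fdim], tf.float32, tf.constant_initializer(0.0), trainable=False)
        var = tf.get_variable("variance", [fdim], tf.float32, tf.constant_initializer(1.0), trainable=False)
        if training:
            # Mask-weighted mean/variance over every axis except the feature axis.
            # list(...) is needed on Python 3, where range() is not a list.
            x_mean, x_var = tf.nn.weighted_moments(x, list(range(ndim - 1)), mask)
            # Exponential moving averages of the batch statistics.
            avg_mean = tf.assign(mean, mean * decay + x_mean * (1.0 - decay))
            avg_var = tf.assign(var, var * decay + x_var * (1.0 - decay))
            # Make the moving-average updates run whenever the output is evaluated.
            with tf.control_dependencies([avg_mean, avg_var]):
                return tf.nn.batch_normalization(x, x_mean, x_var, beta, gamma, eps)
        else:
            # Inference: normalize with the accumulated running statistics.
            return tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)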
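To make the masked statistics concrete without a TensorFlow graph, here is a minimal NumPy sketch of the same computation. It assumes an input of shape [batch, time, dim] and a 0/1 float mask of shape [batch, time, 1]. `masked_moments` and `MaskedBatchNorm` are hypothetical names for illustration, not part of icml17_knn or TensorFlow. The moments use the frequency-weighted mean and (biased) variance that `tf.nn.weighted_moments` computes, and the output uses the `gamma * (x - mean) / sqrt(var + eps) + beta` formula that `tf.nn.batch_normalization` applies.

```python
import numpy as np


def masked_moments(x, mask):
    """Mask-weighted mean and (biased) variance over every axis except the last."""
    w = np.broadcast_to(mask, x.shape).astype(x.dtype)  # 1 = real position, 0 = padding
    axes = tuple(range(x.ndim - 1))                     # reduce all but the feature axis
    total = w.sum(axis=axes)                            # number of real positions per feature
    mean = (w * x).sum(axis=axes) / total
    var = (w * (x - mean) ** 2).sum(axis=axes) / total
    return mean, var


class MaskedBatchNorm:
    """Hypothetical NumPy stand-in for batch_normalization_with_mask (illustration only)."""

    def __init__(self, fdim, decay=0.999, eps=1e-6):
        self.gamma = np.ones(fdim)   # "scale", initialized to 1.0
        self.beta = np.zeros(fdim)   # "offset", initialized to 0.0
        self.mean = np.zeros(fdim)   # running mean, initialized to 0.0
        self.var = np.ones(fdim)     # running variance, initialized to 1.0
        self.decay = decay
        self.eps = eps

    def __call__(self, x, mask, training=True):
        if training:
            # Normalize with the masked batch statistics and update the running averages.
            m, v = masked_moments(x, mask)
            self.mean = self.mean * self.decay + m * (1.0 - self.decay)
            self.var = self.var * self.decay + v * (1.0 - self.decay)
        else:
            # Inference: use the accumulated running statistics, no update.
            m, v = self.mean, self.var
        return self.gamma * (x - m) / np.sqrt(v + self.eps) + self.beta


# Usage: 2 sequences, 3 time steps, 1 feature; the last time step is padding.
x = np.array([[[1.0], [3.0], [100.0]],
              [[5.0], [7.0], [-50.0]]])
mask = np.array([[[1.0], [1.0], [0.0]],
                 [[1.0], [1.0], [0.0]]])
bn = MaskedBatchNorm(fdim=1, decay=0.9)
y_train = bn(x, mask, training=True)   # masked batch mean 4.0, variance 5.0
y_eval = bn(x, mask, training=False)   # running mean about 0.4, variance about 1.4
```

Because padded positions get zero weight, their values (100.0 and -50.0 here) have no effect on the batch statistics. With `training=False` the running averages are used and left unchanged, mirroring the `else` branch of the TensorFlow function.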