bn_lstm.py source code

python

Project: hierarchical-attention-networks  Author: ematvey
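import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope, tf.get_variable)


def batch_norm(x, name_scope, training, epsilon=1e-3, decay=0.999):
    """Batch-normalize a 2-D [batch, values] tensor.

    `training` is a boolean scalar tensor passed to tf.cond: batch statistics
    are used when it is true, population statistics otherwise.
    """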

    with tf.variable_scope(name_scope):
        size = x.get_shape().as_list()[1]

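        # Learned per-feature gain and bias applied after normalization.
        scale = tf.get_variable('scale', [size],
            initializer=tf.constant_initializer(0.1))
        # No initializer is given here, so the variable scope's default is used.
        offset = tf.get_variable('offset', [size])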

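        # Running (population) statistics, updated by assignment rather than
        # by gradient descent, hence trainable=False.
        pop_mean = tf.get_variable('pop_mean', [size],
            initializer=tf.zeros_initializer(),
            trainable=False)
        pop_var = tf.get_variable('pop_var', [size],
            initializer=tf.ones_initializer(),
            trainable=False)
        # Per-feature mean and variance over the batch axis.
        batch_mean, batch_var = tf.nn.moments(x, [0])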

        # Exponential moving average of the batch statistics.
        train_mean_op = tf.assign(
            pop_mean,
            pop_mean * decay + batch_mean * (1 - decay))
        train_var_op = tf.assign(
            pop_var,
            pop_var * decay + batch_var * (1 - decay))

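        def batch_statistics():
            # Training branch: normalize with the current batch statistics and
            # attach the moving-average updates as control dependencies.
            with tf.control_dependencies([train_mean_op, train_var_op]):
                return tf.nn.batch_normalization(
                    x, batch_mean, batch_var, offset, scale, epsilon)

        def population_statistics():
            # Inference branch: normalize with the accumulated statistics.
            return tf.nn.batch_normalization(
                x, pop_mean, pop_var, offset, scale, epsilon)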

        return tf.cond(training, batch_statistics, population_statistics)
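
The moving-average update and the normalization formula used above can be checked by hand without TensorFlow. The following is a minimal pure-Python sketch, not part of the original file: the helper names `ema_update` and `normalize` and the constant batch mean of 5.0 are assumptions made for illustration only.

```python
import math


def ema_update(pop_stat, batch_stat, decay=0.999):
    """One moving-average step, mirroring train_mean_op / train_var_op."""
    return pop_stat * decay + batch_stat * (1 - decay)


def normalize(values, mean, var, offset, scale, epsilon=1e-3):
    """Per-feature formula: scale * (x - mean) / sqrt(var + epsilon) + offset."""
    return [scale * (v - mean) / math.sqrt(var + epsilon) + offset
            for v in values]


# Hypothetical single feature whose batch mean is always 5.0.
pop_mean = 0.0  # same starting value as zeros_initializer()
for _ in range(1000):
    pop_mean = ema_update(pop_mean, 5.0)

# Closed form after n steps: batch + (start - batch) * decay ** n
expected = 5.0 + (0.0 - 5.0) * 0.999 ** 1000
print(pop_mean, expected)
```

With decay=0.999 the running mean has covered only about 63% of the distance to the batch value after 1000 updates, so short training runs may leave `pop_mean` and `pop_var` close to their initial values; a smaller decay would track the batch statistics faster at the cost of noisier estimates.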