base.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:DMNN 作者: magnux 项目源码 文件源码
def batch_norm_layer_in_time(x,  max_length, step, is_training, epsilon=1e-3, decay=0.99, scope="layer"):
    """Batch-normalize ``x`` with separate parameters and statistics per time step.

    Every batch-norm variable (scale, offset, population mean, population
    variance) is stored as one flat vector of length ``size * max_length``,
    where ``size`` is the last dimension of ``x``. The slice
    ``[step * size, (step + 1) * size)`` of each vector belongs to time step
    ``step``. Each step of an unrolled sequence therefore gets its own
    normalization parameters while sharing a single set of variables.

    Args:
        x: Tensor of rank 2 ``[batch, values]``, rank 3 ``[batch, width, values]``
            or rank 4 ``[batch, width, height, values]``. Its last dimension
            must be statically known (it is used to size the variables).
        max_length: Number of time steps the per-step variables are sized for.
        step: Index of the current time step (Python int or integer tensor).
            Presumably expected to lie in ``[0, max_length)``; this is not
            checked here.
        is_training: Evaluated with a Python ``if``, so it must be a Python
            bool rather than a tensor. When true, ``x`` is normalized with its
            batch moments and this step's population statistics are updated.
            Otherwise the stored population statistics are used.
        epsilon: Variance epsilon passed to ``tf.nn.batch_normalization``.
        decay: Exponential-moving-average decay for the population statistics.
        scope: Suffix of the variable scope name ``'bn_' + scope``.

    Returns:
        The normalized tensor, with the same shape as ``x``.
    """
    with tf.variable_scope('bn_' + scope):
        shape = x.get_shape().as_list()
        dim_x = len(shape)
        size = shape[-1]

        # Flat indices of this step's slice inside each [size * max_length] variable.
        step_idcs = tf.range(step * size, (step + 1) * size)

        # Variables are created in the same order, with the same names and
        # initializers, so existing checkpoints still restore.
        scale_var = tf.get_variable('scale', [size * max_length], initializer=tf.constant_initializer(0.1))
        # NOTE(review): no explicit initializer, so the enclosing scope's default
        # (or tf.get_variable's glorot_uniform fallback) is used. Confirm this is
        # intended rather than zeros.
        offset_var = tf.get_variable('offset', [size * max_length])
        pop_mean_var = tf.get_variable('pop_mean', [size * max_length], initializer=tf.zeros_initializer(), trainable=False)
        pop_var_var = tf.get_variable('pop_var', [size * max_length], initializer=tf.ones_initializer(), trainable=False)

        scale = tf.gather(scale_var, step_idcs)
        offset = tf.gather(offset_var, step_idcs)
        pop_mean = tf.gather(pop_mean_var, step_idcs)
        pop_var = tf.gather(pop_var_var, step_idcs)

        if not is_training:
            # Inference: use the accumulated population statistics. No batch
            # moments or update ops are added to the graph on this path.
            return tf.nn.batch_normalization(x, pop_mean, pop_var, offset, scale, epsilon)

        # Moments over every axis except the last (feature) axis.
        batch_mean, batch_var = tf.nn.moments(x, list(range(dim_x - 1)))

        # Exponential moving average, written back only into this step's slice.
        train_mean_op = tf.scatter_update(pop_mean_var, step_idcs, pop_mean * decay + batch_mean * (1 - decay))
        train_var_op = tf.scatter_update(pop_var_var, step_idcs, pop_var * decay + batch_var * (1 - decay))

        # Ops created inside this context depend on both updates, so evaluating
        # the returned tensor also refreshes the population statistics.
        with tf.control_dependencies([train_mean_op, train_var_op]):
            return tf.nn.batch_normalization(x, batch_mean, batch_var, offset, scale, epsilon)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号