nn.py source code


Project: tf_practice, author: juho-lee
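```python
import tensorflow as tf  # TensorFlow 1.x graph-mode API


def batch_norm(input, is_train, scope=None, reuse=None, decay=0.9):
    """Batch normalization for 2-D (N, C) or 4-D NHWC input.

    is_train must be a scalar tf.bool tensor (e.g. a placeholder);
    tf.cond rejects a plain Python bool.
    """
    shape = input.get_shape().as_list()
    num_out = shape[-1]  # number of channels/features to normalize

    # tf.variable_op_scope is deprecated (absent from TF 1.x); the equivalent
    # is tf.variable_scope(name_or_scope, default_name, values).
    with tf.variable_scope(scope, 'BN', [input], reuse=reuse):
        # Learnable per-channel offset (beta) and scale (gamma).
        beta = tf.get_variable('beta', [num_out],
                               initializer=tf.constant_initializer(0.0),
                               trainable=True)
        gamma = tf.get_variable('gamma', [num_out],
                                initializer=tf.constant_initializer(1.0),
                                trainable=True)

        # Per-channel batch statistics: over N, H, W for 4-D input, over N for 2-D.
        axes = [0, 1, 2] if len(shape) == 4 else [0]
        batch_mean, batch_var = tf.nn.moments(input, axes, name='moments')
        ema = tf.train.ExponentialMovingAverage(decay=decay)

        def mean_var_with_update():
            # Training: update the moving averages, then use the batch statistics.
            ema_apply_op = ema.apply([batch_mean, batch_var])
            with tf.control_dependencies([ema_apply_op]):
                return tf.identity(batch_mean), tf.identity(batch_var)

        # Inference branch: use the moving averages instead of batch statistics.
        mean, var = tf.cond(is_train,
                            mean_var_with_update,
                            lambda: (ema.average(batch_mean), ema.average(batch_var)))
        return tf.nn.batch_normalization(input, mean, var, beta, gamma, 1e-3)
```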
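
For readers without a TensorFlow 1.x setup, the following pure-Python sketch mirrors what the graph above computes for 2-D input: batch moments over axis 0 in training mode, an exponential-moving-average update of the form shadow = decay * shadow + (1 - decay) * value, the moving averages in inference mode, and the tf.nn.batch_normalization formula gamma * (x - mean) / sqrt(var + eps) + beta. The class BatchNormSketch and its method names are hypothetical illustrations, not part of tf_practice, and the 4-D NHWC case (axes [0, 1, 2]) is left out for brevity.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]  # rows = batch examples, columns = features


class BatchNormSketch:
    """Hypothetical pure-Python mirror of batch_norm() for 2-D input only.

    Illustrates the arithmetic; it is not code from the tf_practice project.
    """

    def __init__(self, num_out: int, decay: float = 0.9, eps: float = 1e-3):
        self.decay = decay
        self.eps = eps
        self.beta = [0.0] * num_out   # offset, initialised to 0.0 as in the TF code
        self.gamma = [1.0] * num_out  # scale, initialised to 1.0
        # Moving averages of the batch mean/variance start at zero, which is how
        # ExponentialMovingAverage initialises shadows of plain tensors.
        self.ema_mean = [0.0] * num_out
        self.ema_var = [0.0] * num_out

    @staticmethod
    def moments(x: Matrix) -> Tuple[List[float], List[float]]:
        n = len(x)
        cols = list(zip(*x))
        mean = [sum(c) / n for c in cols]
        # Biased (divide-by-n) variance, as tf.nn.moments computes it.
        var = [sum((v - m) ** 2 for v in c) / n for c, m in zip(cols, mean)]
        return mean, var

    def __call__(self, x: Matrix, is_train: bool) -> Matrix:
        if is_train:
            # Training: normalise with batch statistics and update the averages.
            mean, var = self.moments(x)
            d = self.decay
            self.ema_mean = [d * s + (1 - d) * v for s, v in zip(self.ema_mean, mean)]
            self.ema_var = [d * s + (1 - d) * v for s, v in zip(self.ema_var, var)]
        else:
            # Inference: normalise with the moving averages; nothing is updated.
            mean, var = self.ema_mean, self.ema_var
        return [
            [g * (xi - m) / math.sqrt(v + self.eps) + b
             for xi, m, v, b, g in zip(row, mean, var, self.beta, self.gamma)]
            for row in x
        ]


if __name__ == "__main__":
    bn = BatchNormSketch(num_out=2)
    batch = [[1.0, 10.0], [3.0, 30.0]]
    print(bn(batch, is_train=True))   # normalised with batch statistics
    print(bn(batch, is_train=False))  # normalised with the moving averages
```

Because the moving averages start at zero, the inference path is biased toward zero early in training: after k updates the initial zero still carries weight decay**k (0.9**k with the default decay), so the averages only approach the true statistics after a few dozen steps.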