def batch_norm(input_tensor, if_training, epsilon=1e-3):
    """Batch normalization on convolutional feature maps (TF1 graph mode).

    Ref.: http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow

    In the training branch, normalizes with the current mini-batch statistics
    and updates an exponential moving average (decay 0.99) of mean/variance;
    in the inference branch, the moving averages are used instead.

    Args:
        input_tensor: Tensor, 4D NHWC input feature maps.
        if_training: Boolean tf.Variable/tensor; True selects the training
            branch (batch statistics + EMA update).
        epsilon: Small float added to variance to avoid dividing by zero.
            Defaults to 1e-3 (the previously hard-coded value).

    Returns:
        normed_tensor: Batch-normalized feature maps, same shape as input.
    """
    with tf.variable_scope('batch_normalization'):
        # Channel depth is inferred from the last (C) axis of the NHWC input.
        depth = int(input_tensor.get_shape()[-1])
        # Learnable per-channel shift (beta) and scale (gamma).
        beta = tf.Variable(tf.constant(0.0, shape=[depth]),
                           name='beta', trainable=True)
        gamma = tf.Variable(tf.constant(1.0, shape=[depth]),
                            name='gamma', trainable=True)
        # Per-channel statistics over the batch, height and width axes.
        batch_mean, batch_var = tf.nn.moments(input_tensor, [0, 1, 2],
                                              name='moments')
        ema = tf.train.ExponentialMovingAverage(decay=0.99)

        def mean_var_with_update():
            # Apply the EMA update, then return the *batch* statistics;
            # tf.identity makes the returned tensors depend on the update op.
            ema_apply_op = ema.apply([batch_mean, batch_var])
            with tf.control_dependencies([ema_apply_op]):
                return tf.identity(batch_mean), tf.identity(batch_var)

        # Training: batch stats (and EMA update); inference: moving averages.
        mean, var = tf.cond(if_training,
                            mean_var_with_update,
                            lambda: (ema.average(batch_mean),
                                     ema.average(batch_var)))
        normed_tensor = tf.nn.batch_normalization(input_tensor, mean, var,
                                                  beta, gamma, epsilon)
        return normed_tensor
# NOTE: the following text is an artifact from the scraped source page, not code:
# "评论列表" (comment list), "文章目录" (article table of contents)