def batch_norm(x, phase_train, decay=0.9, epsilon=1e-3):
    """
    Batch normalization on convolutional maps.

    Normalizes each channel (axis 3) of a 4D input. In the training phase it
    uses the mean/variance of the current batch over axes 0, 1, 2 and updates
    exponential moving averages of those statistics. Otherwise it uses the
    moving averages. A learned per-channel scale (gamma, initialized to 1) and
    offset (beta, initialized to 0) are then applied.

    Args:
        x: Tensor, 4D BHWD input maps. The depth (axis 3) must be statically
            known; it sizes the per-channel beta/gamma parameters.
        phase_train: Python bool or boolean tensor/variable (converted with
            tf.convert_to_tensor); true selects the training branch.
        decay: decay rate of the moving averages of mean and variance
            (default 0.9, the previously hard-coded value).
        epsilon: small constant added to the variance to avoid division by
            zero (default 1e-3, the previously hard-coded value).
    Return:
        normed: batch-normalized maps.
    Ref: http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow/33950177
    """
    name = 'batch_norm'
    with tf.variable_scope(name):
        phase_train = tf.convert_to_tensor(phase_train, dtype=tf.bool)
        n_out = int(x.get_shape()[3])
        # NOTE(review): the variables are named name + '/...' inside a scope
        # that is already called `name`, so they end up nested twice
        # ('batch_norm/batch_norm/beta'). Kept as-is so existing checkpoints
        # still restore.
        beta = tf.Variable(tf.constant(0.0, shape=[n_out], dtype=x.dtype),
                           name=name + '/beta', trainable=True, dtype=x.dtype)
        gamma = tf.Variable(tf.constant(1.0, shape=[n_out], dtype=x.dtype),
                            name=name + '/gamma', trainable=True, dtype=x.dtype)
        # Per-channel statistics over batch, height and width.
        batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')
        ema = tf.train.ExponentialMovingAverage(decay=decay)

        def mean_var_with_update():
            # Training branch: update the moving averages, then return the
            # current batch statistics. The control dependency ensures the
            # update runs whenever this branch's outputs are evaluated.
            ema_apply_op = ema.apply([batch_mean, batch_var])
            with tf.control_dependencies([ema_apply_op]):
                return tf.identity(batch_mean), tf.identity(batch_var)

        def mean_var_from_averages():
            # Inference branch: use the accumulated moving averages.
            # ema.average() only finds shadow variables that ema.apply()
            # created in the training branch, so this relies on tf.cond
            # building the true branch before the false branch.
            return ema.average(batch_mean), ema.average(batch_var)

        mean, var = tf.cond(phase_train,
                            mean_var_with_update,
                            mean_var_from_averages)
        normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
    return normed
# [web-page residue] Comment list
# [web-page residue] Table of contents