def batch_norm_new(name, input_var, is_train, decay=0.999, epsilon=1e-5):
    """Batch normalization modified from BatchNormLayer in Tensorlayer.

    Source: <https://github.com/zsdonghao/tensorlayer/blob/master/tensorlayer/layers.py#L2190>

    Statistics are computed over every axis except the last one. One
    beta/gamma/moving-statistics entry is kept per element of the last axis.

    Args:
        name: Name of the ``tf.variable_scope`` that holds this layer's variables.
        input_var: Input tensor. Its last dimension must be statically known,
            because it sizes the parameter variables.
        is_train: Selects training or test behaviour. Two forms are accepted:
            - A Python value, evaluated for truthiness when the graph is built
              (same as before).
            - A scalar boolean ``tf.Tensor`` / ``tf.Variable`` (e.g. a
              placeholder). The branch is then chosen at run time with ``tf.cond``.
        decay: Decay rate of the exponential moving averages of mean/variance.
        epsilon: Small constant added to the variance for numerical stability.

    Returns:
        The normalized, scaled (gamma) and shifted (beta) tensor.
    """
    inputs_shape = input_var.get_shape()
    # Reduce over all axes except the last one.
    axis = list(range(len(inputs_shape) - 1))
    params_shape = inputs_shape[-1:]

    with tf.variable_scope(name):
        # Trainable shift (beta) and scale (gamma).
        beta = tf.get_variable('beta',
                               shape=params_shape,
                               initializer=tf.zeros_initializer())
        gamma = tf.get_variable('gamma',
                                shape=params_shape,
                                initializer=tf.random_normal_initializer(mean=1.0, stddev=0.002))
        # Non-trainable running statistics; only the training path updates them.
        moving_mean = tf.get_variable('moving_mean',
                                      params_shape,
                                      initializer=tf.zeros_initializer(),
                                      trainable=False)
        moving_variance = tf.get_variable('moving_variance',
                                          params_shape,
                                          initializer=tf.constant_initializer(1.),
                                          trainable=False)

        def mean_var_with_update():
            """Update the moving averages, then return the batch mean/variance.

            The moments and update ops are deliberately created inside this
            function. If it is used as a ``tf.cond`` branch, ops created
            outside the branch functions run regardless of the selected
            branch, which would overwrite the moving statistics during
            evaluation.
            """
            batch_mean, batch_variance = tf.nn.moments(input_var, axis, name='moments')
            update_moving_mean = moving_averages.assign_moving_average(
                moving_mean, batch_mean, decay, zero_debias=False)
            update_moving_variance = moving_averages.assign_moving_average(
                moving_variance, batch_variance, decay, zero_debias=False)
            with tf.control_dependencies([update_moving_mean, update_moving_variance]):
                return tf.identity(batch_mean), tf.identity(batch_variance)

        def moving_mean_var():
            """Return the stored running statistics (test/inference mode)."""
            return tf.identity(moving_mean), tf.identity(moving_variance)

        if isinstance(is_train, (tf.Tensor, tf.Variable)):
            # Run-time switch: a Python `if` on a graph Tensor raises a TypeError.
            mean, variance = tf.cond(is_train, mean_var_with_update, moving_mean_var)
        elif is_train:
            # Training: use batch statistics and update the moving averages.
            mean, variance = mean_var_with_update()
        else:
            # Testing: use the moving statistics accumulated during training.
            mean, variance = moving_mean, moving_variance

        normed = tf.nn.batch_normalization(input_var, mean, variance, beta, gamma, epsilon)
    return normed
# Web-page residue from the source article (not code):
# 评论列表 ("Comment list") / 文章目录 ("Table of contents")