import tensorflow as tf
from tensorflow.python.training import moving_averages

# Note: _get_variable is a project-level helper (assumed to wrap tf.get_variable,
# e.g. to attach weight decay or variable collections).
def batchNormalization(x, is_training=True, decay=0.9, epsilon=0.001):
    x_shape = x.get_shape()
    params_shape = x_shape[-1:]
    # Normalize over every axis except the last (channel) axis.
    axis = list(range(len(x_shape) - 1))

    beta = _get_variable('beta',
                         params_shape,
                         initializer=tf.zeros_initializer)
    gamma = _get_variable('gamma',
                          params_shape,
                          initializer=tf.ones_initializer)
    moving_mean = _get_variable('moving_mean',
                                params_shape,
                                initializer=tf.zeros_initializer,
                                trainable=False)
    moving_variance = _get_variable('moving_variance',
                                    params_shape,
                                    initializer=tf.ones_initializer,
                                    trainable=False)

    # These ops are only performed during training.
    if is_training:
        # Use the batch statistics and update the moving averages.
        mean, variance = tf.nn.moments(x, axis)
        update_moving_mean = moving_averages.assign_moving_average(
            moving_mean, mean, decay)
        update_moving_variance = moving_averages.assign_moving_average(
            moving_variance, variance, decay)
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_moving_mean)
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_moving_variance)
        return tf.nn.batch_normalization(x, mean, variance, beta, gamma, epsilon)
    else:
        # At inference time, normalize with the accumulated moving statistics.
        return tf.nn.batch_normalization(x, moving_mean, moving_variance,
                                         beta, gamma, epsilon)
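The moving-average updates above are only added to the tf.GraphKeys.UPDATE_OPS collection; they still have to be run as part of the training step. Below is a minimal sketch (TF 1.x graph mode) of how that wiring might look; the input placeholder, the squared-mean loss, and the optimizer are hypothetical placeholders, not part of the original code.

# Hypothetical input and training objective, for illustration only.
x = tf.placeholder(tf.float32, [None, 32, 32, 64])
normalized = batchNormalization(x, is_training=True)
loss = tf.reduce_mean(tf.square(normalized))           # placeholder loss
optimizer = tf.train.GradientDescentOptimizer(0.01)

# Run the registered moving-average updates together with each train step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)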