def _batch_norm(x, name, is_train, decay=0.99, epsilon=1e-3):
    """Apply a batch normalization layer.

    Normalizes ``x`` over every axis except the last one. Each entry of the
    last dimension gets its own mean/variance statistics and its own learned
    ``offset`` (beta) and ``scale`` (gamma).

    Args:
        x: Input tensor. Its last dimension must be statically known, because
            it sizes the per-feature variables.
        name: Name of the variable scope that holds this layer's variables.
        is_train: Boolean scalar tensor, or a Python ``bool``.
            - When true, the batch's own mean/variance are used and the
              moving averages are updated.
            - When false, the stored moving mean/variance are used.
        decay: Decay rate passed to ``assign_moving_average`` for the moving
            mean/variance (default 0.99, as before).
        epsilon: Variance epsilon passed to ``tf.nn.batch_normalization``
            (default 1e-3, as before).

    Returns:
        The batch-normalized tensor produced by ``tf.nn.batch_normalization``.
    """
    with tf.variable_scope(name):
        inputs_shape = x.get_shape()
        # Reduce over every axis except the last one.
        axis = list(range(len(inputs_shape) - 1))
        param_shape = int(inputs_shape[-1])

        # Running statistics: updated by assignment, not by the optimizer.
        moving_mean = tf.get_variable(
            'mean', [param_shape],
            initializer=tf.constant_initializer(0.0), trainable=False)
        moving_var = tf.get_variable(
            'variance', [param_shape],
            initializer=tf.constant_initializer(1.0), trainable=False)
        # Learned affine parameters applied after normalization.
        beta = tf.get_variable(
            'offset', [param_shape], initializer=tf.constant_initializer(0.0))
        gamma = tf.get_variable(
            'scale', [param_shape], initializer=tf.constant_initializer(1.0))

        def mean_var_with_update():
            """Return batch statistics, updating the moving averages."""
            mean, var = tf.nn.moments(x, axis)
            update_moving_mean = moving_averages.assign_moving_average(
                moving_mean, mean, decay)
            update_moving_var = moving_averages.assign_moving_average(
                moving_var, var, decay)
            # The update ops must be dependencies of the tensors this branch
            # returns. Nothing else forces them to run, and ops created inside
            # a tf.cond branch cannot be depended on from outside the cond.
            with tf.control_dependencies([update_moving_mean, update_moving_var]):
                return tf.identity(mean), tf.identity(var)

        def mean_var():
            """Return the stored moving statistics (no update)."""
            return tf.identity(moving_mean), tf.identity(moving_var)

        if isinstance(is_train, bool):
            # TF1 tf.cond rejects a Python bool predicate, so pick the branch
            # statically instead.
            mean, var = mean_var_with_update() if is_train else mean_var()
        else:
            mean, var = tf.cond(is_train, mean_var_with_update, mean_var)

        return tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
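# 评论列表 ("Comment list") - web-page navigation residue, not code
# 文章目录 ("Table of contents") - web-page navigation residue, not code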