# Requires TF 1.x and:
#   from tensorflow.python.training import moving_averages
# UPDATE_OPS_COLLECTION is a module-level collection name defined elsewhere.
def batch_normalization(self, input, name,
                        scale_offset=True,
                        relu=False,
                        decay=0.999,
                        moving_vars='moving_vars'):
    # NOTE: When self.trainable is True, batch statistics are used and the
    # moving averages are updated; otherwise the stored moving averages are used.
    with tf.variable_scope(name):
        axis = list(range(len(input.get_shape()) - 1))
        shape = [input.get_shape()[-1]]
        if scale_offset:
            scale = self.make_var('scale', shape=shape,
                                  initializer=tf.ones_initializer(),
                                  trainable=self.trainable)
            offset = self.make_var('offset', shape=shape,
                                   initializer=tf.zeros_initializer(),
                                   trainable=self.trainable)
        else:
            scale, offset = (None, None)
        # Create moving_mean and moving_variance and add them to the
        # GraphKeys.MOVING_AVERAGE_VARIABLES collection.
        moving_collections = [moving_vars, tf.GraphKeys.MOVING_AVERAGE_VARIABLES]
        moving_mean = self.make_var('mean',
                                    shape,
                                    initializer=tf.zeros_initializer(),
                                    trainable=False,
                                    collections=moving_collections)
        moving_variance = self.make_var('variance',
                                        shape,
                                        initializer=tf.ones_initializer(),
                                        trainable=False,
                                        collections=moving_collections)
        if self.trainable:
            # Calculate the moments based on the individual batch.
            mean, variance = tf.nn.moments(input, axis)
            update_moving_mean = moving_averages.assign_moving_average(
                moving_mean, mean, decay)
            tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_mean)
            update_moving_variance = moving_averages.assign_moving_average(
                moving_variance, variance, decay)
            tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_variance)
        else:
            # Just use the moving_mean and moving_variance.
            mean = moving_mean
            variance = moving_variance
        output = tf.nn.batch_normalization(
            input,
            mean=mean,
            variance=variance,
            offset=offset,
            scale=scale,
            # TODO: This is the default Caffe batch norm eps.
            # Get the actual eps from parameters.
            variance_epsilon=1e-5,
            name=name)
        if relu:
            output = tf.nn.relu(output)
        return output
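
In TF 1.x graph mode, the update ops registered under UPDATE_OPS_COLLECTION during training do not run by themselves; they have to be wired into the training step. A minimal sketch of that wiring follows, where the collection name, loss, and optimizer are stand-ins for illustration, not part of the listing above:

import tensorflow as tf

# Hypothetical stand-ins for the real model and loss; only the update-op
# wiring below is the point of this sketch.
UPDATE_OPS_COLLECTION = 'update_ops'
loss = tf.reduce_mean(tf.square(tf.get_variable('w', shape=[4])))

# Run the queued moving-average updates together with each training step.
update_ops = tf.get_collection(UPDATE_OPS_COLLECTION)
with tf.control_dependencies(update_ops):
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)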