def _batch_norm_without_layers(self, input_layer, decay, use_scale, epsilon):
  """Batch normalization on `input_layer` without tf.layers."""
  # We make this function as similar as possible to
  # tf.contrib.layers.batch_norm, to minimize the differences between using
  # layers and not using layers.
  shape = input_layer.shape
  num_channels = shape[3] if self.data_format == 'NHWC' else shape[1]
  beta = self.get_variable('beta', [num_channels], tf.float32, tf.float32,
                           initializer=tf.zeros_initializer())
  if use_scale:
    gamma = self.get_variable('gamma', [num_channels], tf.float32,
                              tf.float32, initializer=tf.ones_initializer())
  else:
    gamma = tf.constant(1.0, tf.float32, [num_channels])
  # For moving variables, we use tf.get_variable instead of self.get_variable,
  # since self.get_variable returns the result of tf.cast which we cannot
  # assign to.
  moving_mean = tf.get_variable('moving_mean', [num_channels],
                                tf.float32,
                                initializer=tf.zeros_initializer(),
                                trainable=False)
  moving_variance = tf.get_variable('moving_variance', [num_channels],
                                    tf.float32,
                                    initializer=tf.ones_initializer(),
                                    trainable=False)
  if self.phase_train:
    # Training: normalize with batch statistics and queue moving-average
    # updates in UPDATE_OPS so the train op can run them.
    bn, batch_mean, batch_variance = tf.nn.fused_batch_norm(
        input_layer, gamma, beta, epsilon=epsilon,
        data_format=self.data_format, is_training=True)
    mean_update = moving_averages.assign_moving_average(
        moving_mean, batch_mean, decay=decay, zero_debias=False)
    variance_update = moving_averages.assign_moving_average(
        moving_variance, batch_variance, decay=decay, zero_debias=False)
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, mean_update)
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, variance_update)
  else:
    # Inference: normalize with the accumulated moving statistics.
    bn, _, _ = tf.nn.fused_batch_norm(
        input_layer, gamma, beta, mean=moving_mean,
        variance=moving_variance, epsilon=epsilon,
        data_format=self.data_format, is_training=False)
  return bn
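Below is a minimal usage sketch, assuming graph-mode TensorFlow 1.x with `import tensorflow as tf` and `from tensorflow.python.training import moving_averages` already in scope (the standard imports this snippet depends on). The `Builder` class, its `get_variable` stand-in, the input shape, and the hyperparameter values are illustrative assumptions, not part of the benchmark code; the real model builder supplies `get_variable`, `data_format`, and `phase_train`.

import tensorflow as tf
from tensorflow.python.training import moving_averages  # provides assign_moving_average


class Builder(object):
  """Illustrative stand-in for the benchmark's model builder."""

  def __init__(self, data_format='NHWC', phase_train=True):
    self.data_format = data_format
    self.phase_train = phase_train

  def get_variable(self, name, shape, dtype, cast_dtype, initializer=None):
    # Simplified stand-in: create a variable, then cast to the compute dtype.
    var = tf.get_variable(name, shape, dtype, initializer=initializer)
    return tf.cast(var, cast_dtype)


images = tf.placeholder(tf.float32, [32, 224, 224, 64])  # NHWC input (assumed shape)
builder = Builder(phase_train=True)
with tf.variable_scope('bn_example'):
  # Call the function above with the builder standing in for `self`.
  normalized = _batch_norm_without_layers(
      builder, images, decay=0.997, use_scale=True, epsilon=1e-5)

# In training mode the moving-average updates were added to UPDATE_OPS, so the
# train op must depend on them, e.g. via a control dependency:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
  train_op = tf.no_op(name='train_step')  # stand-in for the real optimizer step

Note the design choice visible in the method: because the training branch only adds the update ops to `tf.GraphKeys.UPDATE_OPS` rather than wiring them into the output, the caller is responsible for running those updates, exactly as with `tf.contrib.layers.batch_norm`.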