def __call__(self, x):
    """Batch-normalize ``x`` and return the normalized tensor.

    Inside ``tf.variable_scope(self.name)`` this creates a ``gamma`` scale
    variable (initialized from a normal distribution with mean 1.0 and
    stddev 0.02) and a ``beta`` offset variable (initialized to 0.0). Both
    have length ``self.in_dim``, or the size of the last axis of ``x`` when
    ``self.in_dim`` is falsy.

    Batch statistics are computed over every axis except the last one, so
    the last axis is treated as the channel/feature axis. ``self.ema.apply``
    is called on those statistics and the returned op is stored in
    ``self.ema_apply_op``.

    When ``self.train`` is truthy, the current batch mean/variance are used.
    Otherwise the moving averages ``self.ema.average(...)`` are used.

    Args:
        x: Input tensor. Assumed channels-last, e.g. NHWC for images or
            (batch, features) for dense layers.

    Returns:
        The tensor produced by
        ``tf.nn.batch_norm_with_global_normalization`` with
        ``scale_after_normalization=True``.

    Side effects:
        Sets ``self.gamma``, ``self.beta``, ``self.mean``,
        ``self.variance`` and ``self.ema_apply_op``.
    """
    shape = x.get_shape()
    num_channels = self.in_dim or shape[-1]

    # Reduce over every axis except the last (channel) axis. For 4-D input
    # this is exactly [0, 1, 2], the previously hard-coded value; it also
    # lets 2-D (fully connected) input work. Keep the historical 4-D axes
    # when the static rank is unknown.
    if shape.ndims is None:
        axes = [0, 1, 2]
    else:
        axes = list(range(shape.ndims - 1))

    with tf.variable_scope(self.name):
        self.gamma = tf.get_variable(
            "gamma", [num_channels],
            initializer=tf.random_normal_initializer(1., 0.02))
        self.beta = tf.get_variable(
            "beta", [num_channels],
            initializer=tf.constant_initializer(0.))

        self.mean, self.variance = tf.nn.moments(x, axes)
        # Pin the static shape so downstream ops see a known channel count.
        self.mean.set_shape((num_channels,))
        self.variance.set_shape((num_channels,))

        # NOTE(review): this update op is NOT attached as a control
        # dependency of the output, so the moving averages only change if
        # the caller runs ``self.ema_apply_op`` itself. Otherwise the
        # inference branch below reads the initial shadow values. Confirm
        # that callers run it (assumes self.ema is a
        # tf.train.ExponentialMovingAverage).
        self.ema_apply_op = self.ema.apply([self.mean, self.variance])

        if self.train:
            mean, variance = self.mean, self.variance
        else:
            mean = self.ema.average(self.mean)
            variance = self.ema.average(self.variance)

        normalized_x = tf.nn.batch_norm_with_global_normalization(
            x, mean, variance, self.beta, self.gamma, self.epsilon,
            scale_after_normalization=True)
    return normalized_x
评论列表
文章目录