def normalize(self, x, train=True):
    """Return a batch-normalized version of ``x``.

    Args:
        x: Input tensor. Moments are reduced over axes 0, 1 and 2, so ``x``
            must have rank >= 3. Presumably a 4-D NHWC activation map, which
            yields one mean/variance per channel (TODO confirm with callers).
        train: If truthy, normalize with the current batch's statistics and
            assign them into ``self.mean`` / ``self.variance``. If falsy
            (``False`` or ``None``), normalize with ``self.ewma_trainer``'s
            averages of those variables instead.

    Returns:
        The tensor produced by ``tf.nn.batch_norm_with_global_normalization``
        using ``self.beta``, ``self.gamma``, ``self.epsilon`` and
        ``self.scale_after_norm``.
    """
    # Test truthiness rather than `is not None`: the latter sent
    # train=False down the training path.
    if train:
        # Statistics of this batch, reduced over axes 0, 1 and 2.
        mean, variance = tf.nn.moments(x, [0, 1, 2])
        assign_mean = self.mean.assign(mean)
        assign_variance = self.variance.assign(variance)
        # Ops built inside this context carry control dependencies on the
        # two assigns, so evaluating the output also stores the batch
        # statistics in self.mean / self.variance.
        with tf.control_dependencies([assign_mean, assign_variance]):
            return tf.nn.batch_norm_with_global_normalization(
                x, mean, variance, self.beta, self.gamma,
                self.epsilon, self.scale_after_norm)

    # Inference: use the averages tracked by self.ewma_trainer instead of
    # statistics of the current batch.
    mean = self.ewma_trainer.average(self.mean)
    variance = self.ewma_trainer.average(self.variance)
    # NOTE(review): why beta/gamma are wrapped in tf.identity is not visible
    # here; kept unchanged from the original.
    local_beta = tf.identity(self.beta)
    local_gamma = tf.identity(self.gamma)
    return tf.nn.batch_norm_with_global_normalization(
        x, mean, variance, local_beta, local_gamma,
        self.epsilon, self.scale_after_norm)
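# 评论列表 ("Comment list") -- scraped web-page residue, not code
# 文章目录 ("Table of contents") -- scraped web-page residue, not code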