def __call__(self, x, train=True):
    """Batch-normalize ``x`` over its first (batch) axis.

    Creates (or fetches, depending on the enclosing scope's reuse setting)
    the variables ``beta``, ``gamma``, ``mean`` and ``variance`` under
    ``tf.variable_scope(self.name)``. Each variable is shaped like ``x``
    without its first dimension.

    NOTE(review): ``self.name``, ``self.ema`` and ``self.epsilon`` are set
    outside this method. ``self.ema`` is presumably a
    ``tf.train.ExponentialMovingAverage`` (it is used via ``apply`` and
    ``average``) -- confirm against the constructor.

    Args:
        x: Input tensor. Statistics are reduced over axis 0 only, so every
            remaining position gets its own mean and variance. The non-batch
            dimensions must be statically known, because they size the
            variables.
        train: If True, normalize with the current batch statistics and
            update the moving averages kept by ``self.ema``. If False,
            normalize with those moving averages.

    Returns:
        The normalized tensor produced by ``tf.nn.batch_normalization``.

    Raises:
        ValueError: If ``train`` is False but ``self.ema`` holds no moving
            averages for ``mean``/``variance`` (this object was never called
            with ``train=True``).
    """
    shape = x.get_shape().as_list()
    params_shape = shape[1:]
    with tf.variable_scope(self.name):
        self.beta = tf.get_variable(
            "beta", params_shape,
            initializer=tf.constant_initializer(0.))
        self.gamma = tf.get_variable(
            "gamma", params_shape,
            initializer=tf.random_normal_initializer(1., 0.02))
        # Non-trainable holders for the latest batch statistics. self.ema
        # keeps a smoothed shadow copy of each one for use at inference.
        self.mean = tf.get_variable(
            "mean", params_shape,
            initializer=tf.constant_initializer(0.), trainable=False)
        self.variance = tf.get_variable(
            "variance", params_shape,
            initializer=tf.constant_initializer(1.), trainable=False)

        if train:
            batch_mean, batch_var = tf.nn.moments(x, [0], name='moments')
            # The assign ops must actually be executed. Previously their
            # results were discarded, so they never ran in graph mode and the
            # EMA only ever averaged the initial 0/1 values. Building the EMA
            # update under these control dependencies sequences it after the
            # assignments.
            assign_mean = self.mean.assign(batch_mean)
            assign_var = self.variance.assign(batch_var)
            with tf.control_dependencies([assign_mean, assign_var]):
                ema_apply_op = self.ema.apply([self.mean, self.variance])
            # Hang the EMA update off the returned statistics so it runs
            # whenever the normalized output is evaluated.
            with tf.control_dependencies([ema_apply_op]):
                mean, var = tf.identity(batch_mean), tf.identity(batch_var)
        else:
            mean = self.ema.average(self.mean)
            var = self.ema.average(self.variance)
            # ExponentialMovingAverage.average returns None for variables it
            # has never been applied to; fail clearly instead of deep inside
            # batch_normalization.
            if mean is None or var is None:
                raise ValueError(
                    "batch norm '%s': no moving averages available; call it "
                    "with train=True before train=False" % (self.name,))
        normed = tf.nn.batch_normalization(
            x, mean, var, self.beta, self.gamma, self.epsilon)
    return normed
评论列表
文章目录