def _bn(self, x, params_init, is_training):
    # Batch normalization initialized from pre-trained parameters:
    # params_init['weight'] -> gamma, ['bias'] -> beta,
    # ['running_mean'] / ['running_var'] -> the moving statistics.
    x_shape = x.get_shape()
    axis = list(range(len(x_shape) - 1))
    # These manually created variables are only referenced by the commented-out
    # manual implementation below; tf.layers.batch_normalization creates its own.
    beta = self._get_variable_const('beta', initializer=tf.constant(params_init['bias']))
    gamma = self._get_variable_const('gamma', initializer=tf.constant(params_init['weight']))
    moving_mean = self._get_variable_const('moving_mean',
                                           initializer=tf.constant(params_init['running_mean']),
                                           trainable=False)
    moving_variance = self._get_variable_const('moving_variance',
                                               initializer=tf.constant(params_init['running_var']),
                                               trainable=False)
    # --- Earlier manual implementation, kept for reference ---
    # mean, variance = tf.nn.moments(x, axis)
    # update_moving_mean = moving_averages.assign_moving_average(moving_mean, mean, BN_DECAY)
    # update_moving_variance = moving_averages.assign_moving_average(moving_variance, variance, BN_DECAY)
    # tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_mean)
    # tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_variance)
    #
    # if not is_training:   # note: `~is_training` would be a bitwise NOT, not a logical one
    #     mean = moving_mean
    #     variance = moving_variance
    # else:
    #     ema = tf.train.ExponentialMovingAverage(decay=BN_DECAY)
    #
    #     def mean_var_with_update():
    #         ema_apply_op = ema.apply([mean, variance])
    #         with tf.control_dependencies([ema_apply_op]):
    #             return tf.identity(mean), tf.identity(variance)
    #     mean, variance = mean_var_with_update()
    #
    # mean, variance = control_flow_ops.cond(is_training, lambda: (mean, variance),
    #                                        lambda: (moving_mean, moving_variance))
    # x = tf.nn.batch_normalization(x, mean, variance, beta, gamma, BN_EPSILON)
    # tf.layers.batch_normalization creates its own gamma/beta/moving statistics;
    # when training=True it registers update ops in tf.GraphKeys.UPDATE_OPS (see the sketch below).
    x = tf.layers.batch_normalization(
        x, momentum=BN_DECAY, epsilon=BN_EPSILON,
        beta_initializer=tf.constant_initializer(params_init['bias']),
        gamma_initializer=tf.constant_initializer(params_init['weight']),
        moving_mean_initializer=tf.constant_initializer(params_init['running_mean']),
        moving_variance_initializer=tf.constant_initializer(params_init['running_var']),
        training=is_training)
    return x
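
For context, here is a minimal usage sketch. It assumes the batch-norm parameters come from a PyTorch state_dict and that a loss tensor already exists elsewhere; bn_params_from_state_dict, torch_state_dict, the 'layer1.0.bn1' prefix, model, and loss are hypothetical names used only for illustration, not part of the original code. It also shows the UPDATE_OPS dependency that tf.layers.batch_normalization requires when training=True.

import numpy as np
import tensorflow as tf

def bn_params_from_state_dict(state_dict, prefix):
    # Hypothetical helper: map one PyTorch BatchNorm layer's entries to the
    # keys that _bn expects ('weight' -> gamma, 'bias' -> beta, etc.).
    to_np = lambda t: t.cpu().numpy() if hasattr(t, 'cpu') else np.asarray(t)
    return {
        'weight': to_np(state_dict[prefix + '.weight']),
        'bias': to_np(state_dict[prefix + '.bias']),
        'running_mean': to_np(state_dict[prefix + '.running_mean']),
        'running_var': to_np(state_dict[prefix + '.running_var']),
    }

# params_init = bn_params_from_state_dict(torch_state_dict, 'layer1.0.bn1')
# y = model._bn(x, params_init, is_training=True)

# When training=True, the moving-average updates live in tf.GraphKeys.UPDATE_OPS
# and must be made a dependency of the train op, otherwise the moving statistics
# are never refreshed:
# update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
# with tf.control_dependencies(update_ops):
#     train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)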