def _bn(x, is_training, hypes):
    """Apply batch normalization to ``x`` over every axis except the last.

    Four variables shaped like the last dimension of ``x`` are requested via
    ``_get_variable``: ``beta`` (zeros), ``gamma`` (ones), and the
    non-trainable ``moving_mean`` (zeros) and ``moving_variance`` (ones).

    Args:
        x: Input tensor. ``len(x.get_shape())`` is used to derive the
            reduction axes, and the last dimension sizes the parameters.
        is_training: Predicate passed to ``control_flow_ops.cond``. When it
            holds, batch statistics are used; otherwise the moving averages
            are used. Only consulted when
            ``hypes['use_moving_average_bn']`` is truthy.
        hypes: Mapping that must provide the key
            ``'use_moving_average_bn'``.

    Returns:
        The output of ``tf.nn.batch_normalization`` for ``x``.
    """
    input_shape = x.get_shape()
    channel_shape = input_shape[-1:]
    reduce_axes = list(range(len(input_shape) - 1))

    beta = _get_variable(
        'beta', channel_shape, initializer=tf.zeros_initializer())
    gamma = _get_variable(
        'gamma', channel_shape, initializer=tf.ones_initializer())
    moving_mean = _get_variable(
        'moving_mean', channel_shape,
        initializer=tf.zeros_initializer(), trainable=False)
    moving_variance = _get_variable(
        'moving_variance', channel_shape,
        initializer=tf.ones_initializer(), trainable=False)

    batch_mean, batch_variance = tf.nn.moments(x, reduce_axes)

    # The moving-average update ops are always built, but they are only
    # registered in UPDATE_OPS_COLLECTION when the flag below is set.
    mean_update = moving_averages.assign_moving_average(
        moving_mean, batch_mean, BN_DECAY)
    variance_update = moving_averages.assign_moving_average(
        moving_variance, batch_variance, BN_DECAY)

    # Default: normalize with the statistics of the current batch.
    mean, variance = batch_mean, batch_variance

    if hypes['use_moving_average_bn']:
        for update_op in (mean_update, variance_update):
            tf.add_to_collection(UPDATE_OPS_COLLECTION, update_op)

        # Pick batch statistics while training, moving averages otherwise.
        mean, variance = control_flow_ops.cond(
            is_training,
            lambda: (batch_mean, batch_variance),
            lambda: (moving_mean, moving_variance))

    # NOTE: the original author left an open question here about whether the
    # static shape of the result should be re-applied with set_shape().
    return tf.nn.batch_normalization(
        x, mean, variance, beta, gamma, BN_EPSILON)
评论列表
文章目录