def bn(x, c):
    """Batch-normalize ``x`` over every axis except the last one.

    Four variables are obtained through ``_get_variable``, each shaped like
    the last dimension of ``x``: ``beta`` (zeros) and ``gamma`` (ones), plus
    the non-trainable ``moving_mean`` (zeros) and ``moving_variance`` (ones).

    Ops that fold the current batch statistics into the moving statistics
    (decay ``BN_DECAY``) are created and appended to
    ``UPDATE_OPS_COLLECTION``. This function only registers them; running
    them is left to whoever consumes that collection.

    Args:
        x: Input tensor. Its rank must be known, since the reduction axes
            are derived from ``len(x.get_shape())``.
        c: Mapping holding an ``'is_training'`` entry that is used as the
            ``cond`` predicate (presumably a boolean tensor -- confirm
            against the caller).

    Returns:
        The result of ``tf.nn.batch_normalization`` applied to ``x`` using
        batch statistics when ``c['is_training']`` selects the first branch
        and the moving statistics otherwise.
    """
    input_shape = x.get_shape()
    # One parameter per entry of the last dimension.
    channel_shape = input_shape[-1:]
    # Statistics are reduced over all leading axes.
    reduce_axes = list(range(len(input_shape) - 1))

    # Offset and scale applied after normalization.
    beta = _get_variable(
        'beta', channel_shape, initializer=tf.zeros_initializer())
    gamma = _get_variable(
        'gamma', channel_shape, initializer=tf.ones_initializer())

    # Running statistics, excluded from the trainable variables.
    moving_mean = _get_variable(
        'moving_mean', channel_shape,
        initializer=tf.zeros_initializer(), trainable=False)
    moving_variance = _get_variable(
        'moving_variance', channel_shape,
        initializer=tf.ones_initializer(), trainable=False)

    batch_mean, batch_variance = tf.nn.moments(x, reduce_axes)

    # Build both update ops first, then register them, so the collection
    # receives the mean update before the variance update.
    update_ops = [
        moving_averages.assign_moving_average(running, batch, BN_DECAY)
        for running, batch in ((moving_mean, batch_mean),
                               (moving_variance, batch_variance))
    ]
    for update_op in update_ops:
        tf.add_to_collection(UPDATE_OPS_COLLECTION, update_op)

    def _batch_statistics():
        return batch_mean, batch_variance

    def _moving_statistics():
        return moving_mean, moving_variance

    mean, variance = control_flow_ops.cond(
        c['is_training'], _batch_statistics, _moving_statistics)

    return tf.nn.batch_normalization(
        x, mean, variance, beta, gamma, BN_EPSILON)
# wrapper for get_variable op
# (web page residue) Comment list
# (web page residue) Table of contents