def bn(x, c):
    """Batch-normalize ``x`` over every axis except the last one.

    Four variables shaped like the last dimension of ``x`` are obtained via
    ``tf.get_variable``: trainable ``beta`` and ``gamma``, and non-trainable
    ``moving_mean`` and ``moving_variance``. Each is placed in both
    ``tf.GraphKeys.GLOBAL_VARIABLES`` and ``GC_VARIABLES``.

    The ops that fold the current batch statistics into the moving averages
    are only *registered* in ``UPDATE_OPS_COLLECTION``; this function does
    not execute them or attach them as control dependencies.

    Args:
        x: Input tensor; must expose ``get_shape()``.
        c: Config mapping. Only ``c['is_training']`` is read here, and it is
            used as the predicate of a graph conditional (presumably a
            boolean tensor -- confirm against the caller).

    Returns:
        The output of ``tf.nn.batch_normalization`` applied to ``x``.
    """
    input_shape = x.get_shape()
    channel_shape = input_shape[-1:]
    # Reduce over all leading axes, keeping only the last one.
    reduce_axes = list(range(len(input_shape) - 1))

    def _channel_variable(name, initializer, trainable):
        # Every BN variable shares shape, dtype and collections; only the
        # name, initializer and trainability differ.
        return tf.get_variable(
            name,
            shape=channel_shape,
            initializer=initializer,
            dtype='float32',
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, GC_VARIABLES],
            trainable=trainable)

    beta = _channel_variable('beta', tf.zeros_initializer(), True)
    gamma = _channel_variable('gamma', tf.ones_initializer(), True)
    moving_mean = _channel_variable(
        'moving_mean', tf.zeros_initializer(), False)
    moving_variance = _channel_variable(
        'moving_variance', tf.ones_initializer(), False)

    # Statistics of the current batch and the ops that blend them into the
    # moving averages using BN_DECAY.
    batch_mean, batch_variance = tf.nn.moments(x, reduce_axes)
    moving_stat_updates = (
        moving_averages.assign_moving_average(
            moving_mean, batch_mean, BN_DECAY),
        moving_averages.assign_moving_average(
            moving_variance, batch_variance, BN_DECAY),
    )
    for update_op in moving_stat_updates:
        tf.add_to_collection(UPDATE_OPS_COLLECTION, update_op)

    def _batch_stats():
        return batch_mean, batch_variance

    def _moving_stats():
        return moving_mean, moving_variance

    # Use batch statistics when c['is_training'] holds, otherwise fall back
    # to the stored moving averages.
    mean, variance = control_flow_ops.cond(
        c['is_training'], _batch_stats, _moving_stats)

    return tf.nn.batch_normalization(
        x, mean, variance, beta, gamma, BN_EPSILON)
# resnet block
# (web page residue) Comment list
# (web page residue) Table of contents