def batch_norm(data, name, decay=0.9, epsilon=0.001):
    """Batch-normalize `data` per channel along its last axis.

    Four variables are created via ``tf.get_variable``, all with shape equal to
    the size of the last axis of ``data``:
      * ``<name>_beta``  (offset, initialized to 0, trainable)
      * ``<name>_gamma`` (scale, initialized to 1, trainable)
      * ``<name>_moving_mean``     (initialized to 0, not trainable)
      * ``<name>_moving_variance`` (initialized to 1, not trainable)

    When ``FLAGS.train_mode`` is truthy, mean and variance are computed over
    every axis except the last (e.g. N, H, W for an input that is presumably
    NHWC -- confirm with callers). The input is normalized with these batch
    statistics, and the moving averages are updated as a side effect of
    evaluating the output. Otherwise the stored moving statistics are used.

    Args:
        data: Input tensor whose rank and last dimension are statically known.
        name: Prefix for the created variables and ops.
        decay: Decay passed to ``assign_moving_average`` (default 0.9).
        epsilon: Variance epsilon for ``tf.nn.batch_normalization``
            (default 0.001).

    Returns:
        The normalized tensor, with the same shape as ``data``.
    """
    shape_param = data.get_shape()[-1]
    beta = tf.get_variable(name=name + '_beta', shape=shape_param, dtype=tf.float32,
                           initializer=tf.constant_initializer(0.0, tf.float32))
    gamma = tf.get_variable(name=name + '_gamma', shape=shape_param, dtype=tf.float32,
                            initializer=tf.constant_initializer(1.0, tf.float32))
    # The moving statistics are needed in both modes: they are updated in
    # training and read at inference, so create them once.
    moving_mean = tf.get_variable(name=name + '_moving_mean', shape=shape_param, dtype=tf.float32,
                                  initializer=tf.constant_initializer(0.0, tf.float32),
                                  trainable=False)
    moving_variance = tf.get_variable(name=name + '_moving_variance', shape=shape_param,
                                      dtype=tf.float32,
                                      initializer=tf.constant_initializer(1.0, tf.float32),
                                      trainable=False)

    if FLAGS.train_mode:
        # Reduce over every axis but the last (channel) one; for a 4-D input
        # this is [0, 1, 2], as before.
        reduce_axes = list(range(data.get_shape().ndims - 1))
        batch_mean, batch_variance = tf.nn.moments(x=data, axes=reduce_axes,
                                                   name=name + '_moments')
        update_mean = moving_averages.assign_moving_average(
            variable=moving_mean, value=batch_mean, decay=decay)
        update_variance = moving_averages.assign_moving_average(
            variable=moving_variance, value=batch_variance, decay=decay)
        # Normalize with the *batch* statistics so gradients flow through
        # them. The identity ops tie the moving-average updates to this
        # layer's output, so the updates run whenever the output is evaluated.
        with tf.control_dependencies([update_mean, update_variance]):
            mean = tf.identity(batch_mean)
            variance = tf.identity(batch_variance)
    else:
        mean, variance = moving_mean, moving_variance

    # The statistics are per-channel vectors, not scalars, so record them as
    # histograms; tf.summary.scalar would reject a non-scalar tensor.
    tf.summary.histogram(moving_mean.op.name, moving_mean)
    tf.summary.histogram(moving_variance.op.name, moving_variance)

    return tf.nn.batch_normalization(x=data, mean=mean, variance=variance,
                                     offset=beta, scale=gamma,
                                     variance_epsilon=epsilon, name=name)
# Comments (scraped page navigation text: "评论列表")
# Table of contents (scraped page navigation text: "文章目录")