def batch_normalization_with_mask(x, mask, scope, decay=0.999, eps=1e-6, training=True):
    """Batch-normalize ``x`` per channel using statistics weighted by ``mask``.

    Under ``scope`` this gets four float32 variables of shape ``[fdim]``
    (``fdim`` = last dimension of ``x``):

    - trainable ``scale`` (gamma, initialized to 1);
    - trainable ``offset`` (beta, initialized to 0);
    - non-trainable ``mean`` (initialized to 0);
    - non-trainable ``variance`` (initialized to 1).

    ``mean`` and ``variance`` hold exponential moving averages of the masked
    batch statistics.

    Args:
        x: Input tensor whose rank and last dimension are statically known.
            Statistics are reduced over every axis except the last one.
        mask: Frequency weights passed to ``tf.nn.weighted_moments``. They must
            be broadcast-compatible with ``x``.
            NOTE(review): for a (batch, seq) padding mask and a
            (batch, seq, dim) input, the caller presumably has to expand the
            mask to (batch, seq, 1) first — confirm against callers.
        scope: Name of the variable scope that holds the parameters.
        decay: Decay factor for the moving mean/variance updates.
        eps: Variance epsilon passed to ``tf.nn.batch_normalization``.
        training: Python bool.
            - If True, normalize with the masked batch statistics and update
              the moving averages.
            - If False, normalize with the stored moving averages.

    Returns:
        The output of ``tf.nn.batch_normalization`` applied to ``x``.
    """
    shape = x.get_shape().as_list()
    ndim = len(shape)
    fdim = shape[-1]
    with tf.variable_scope(scope):
        gamma = tf.get_variable("scale", [fdim], tf.float32, tf.constant_initializer(1.0))
        beta = tf.get_variable("offset", [fdim], tf.float32, tf.constant_initializer(0.0))
        mean = tf.get_variable("mean", [fdim], tf.float32, tf.constant_initializer(0.0), trainable=False)
        var = tf.get_variable("variance", [fdim], tf.float32, tf.constant_initializer(1.0), trainable=False)
        if training:
            # Reduce over every axis except the channel (last) axis.
            # Materialize the axes as a list: on Python 3, range() is a lazy
            # object, and TF ops with list-valued attributes expect a list.
            axes = list(range(ndim - 1))
            x_mean, x_var = tf.nn.weighted_moments(x, axes, mask)
            # NOTE(review): if every mask weight in the batch is zero, the
            # weighted statistics are presumably NaN (zero total weight), and
            # the assigns below would store NaN in the moving averages
            # permanently. Consider guarding if all-padding batches can occur.
            avg_mean = tf.assign(mean, mean * decay + x_mean * (1.0 - decay))
            avg_var = tf.assign(var, var * decay + x_var * (1.0 - decay))
            # Nothing else consumes the assign ops. This control dependency
            # makes them run whenever the normalized output is evaluated.
            with tf.control_dependencies([avg_mean, avg_var]):
                return tf.nn.batch_normalization(x, x_mean, x_var, beta, gamma, eps)
        return tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)
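# Stray navigation text from the source web page (not code):
# "评论列表" (comment list), "文章目录" (article table of contents).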