def batch_norm(x,
               is_training,
               gamma=None,
               beta=None,
               axes=(0, 1, 2),
               eps=1e-10,
               name="bn_out",
               decay=0.99,
               dtype=tf.float32):
    """Applies batch normalization with exponential-moving-average statistics.

    Mean and variance are collected over `axes` (by default every axis except
    the last, i.e. all but the channel dimension of an NHWC tensor — TODO
    confirm layout against callers) and the input is normalized as:

        x_ = gamma * (x - mean) / sqrt(var + eps) + beta

    During training, batch statistics are used directly and update ops for
    the moving averages are returned; at inference, the stored moving
    averages are used instead.

    Args:
      x: Input tensor, [B, ...]; the last dimension is the channel depth.
      is_training: Python bool. True to normalize with (and produce update
        ops for) batch statistics; False to use the stored moving averages.
      gamma: Optional scaling parameter (None disables scaling).
      beta: Optional bias parameter (None disables the shift).
      axes: Axes over which to collect statistics.
      eps: Small bias added to the variance denominator.
      name: Name of the output op.
      decay: EMA decay rate for the moving mean/variance.
      dtype: Data type of the EMA variables.

    Returns:
      A tuple `(normed, update_ops)`: `normed` is the batch-normalized
      tensor; `update_ops` is a list of the two EMA assign ops when
      `is_training` is True (the caller is responsible for running them),
      otherwise None.
    """
    n_out = x.get_shape()[-1]
    try:
        # Static channel depth is known: give the EMA variables a fixed shape.
        shape = [int(n_out)]
    except (TypeError, ValueError):
        # Depth unknown at graph-construction time; leave the shape unset.
        shape = None
    emean = tf.get_variable(
        "ema_mean",
        shape=shape,
        trainable=False,
        dtype=dtype,
        initializer=tf.constant_initializer(0.0, dtype=dtype))
    evar = tf.get_variable(
        "ema_var",
        shape=shape,
        trainable=False,
        dtype=dtype,
        initializer=tf.constant_initializer(1.0, dtype=dtype))
    if is_training:
        mean, var = tf.nn.moments(x, list(axes), name="moments")
        # Standard EMA update: e -= (e - batch_stat) * (1 - decay).
        ema_mean_op = tf.assign_sub(emean, (emean - mean) * (1 - decay))
        ema_var_op = tf.assign_sub(evar, (evar - var) * (1 - decay))
        normed = tf.nn.batch_normalization(
            x, mean, var, beta, gamma, eps, name=name)
        return normed, [ema_mean_op, ema_var_op]
    normed = tf.nn.batch_normalization(
        x, emean, evar, beta, gamma, eps, name=name)
    return normed, None
# NOTE(review): stray scraped-page text removed from code path
# (was: "评论列表" = comment list, "文章目录" = article table of contents)