def batch_norm(x, name_scope, training, epsilon=1e-3, decay=0.999):
    """Batch-normalize a 2-D tensor while tracking population statistics.

    When ``training`` is true, ``x`` is normalized with the mean/variance of
    the current batch (moments over axis 0), and the non-trainable
    ``pop_mean``/``pop_var`` variables are updated as exponential moving
    averages of those batch moments. Otherwise ``x`` is normalized with the
    stored population statistics, and they are left untouched.

    Args:
        x: Tensor assumed to be 2-D ``[batch, values]``. The size of axis 1
            must be known statically because it sizes the variables.
        name_scope: Name of the ``tf.variable_scope`` holding this layer's
            variables ``scale``, ``offset``, ``pop_mean`` and ``pop_var``.
        training: Boolean scalar tensor, or a Python bool, selecting batch
            statistics (true) or population statistics (false).
        epsilon: Small constant added to the variance for numerical
            stability.
        decay: Moving-average decay applied to the population statistics on
            each training step.

    Returns:
        The normalized tensor, produced by ``tf.nn.batch_normalization``.
    """
    with tf.variable_scope(name_scope):
        size = x.get_shape().as_list()[1]
        scale = tf.get_variable('scale', [size],
                                initializer=tf.constant_initializer(0.1))
        # NOTE(review): no initializer is given, so this uses the enclosing
        # variable scope's default initializer rather than zeros (the usual
        # starting value for a batch-norm offset) -- confirm this is intended.
        offset = tf.get_variable('offset', [size])
        pop_mean = tf.get_variable('pop_mean', [size],
                                   initializer=tf.zeros_initializer(),
                                   trainable=False)
        pop_var = tf.get_variable('pop_var', [size],
                                  initializer=tf.ones_initializer(),
                                  trainable=False)

        def batch_statistics():
            batch_mean, batch_var = tf.nn.moments(x, [0])
            # The moving-average updates must be created *inside* this branch.
            # tf.cond only gates ops built inside its branch functions; ops
            # built outside and merely depended on here run unconditionally,
            # which would overwrite the population statistics at inference.
            train_mean_op = tf.assign(
                pop_mean, pop_mean * decay + batch_mean * (1 - decay))
            train_var_op = tf.assign(
                pop_var, pop_var * decay + batch_var * (1 - decay))
            with tf.control_dependencies([train_mean_op, train_var_op]):
                return tf.nn.batch_normalization(
                    x, batch_mean, batch_var, offset, scale, epsilon)

        def population_statistics():
            return tf.nn.batch_normalization(
                x, pop_mean, pop_var, offset, scale, epsilon)

        if isinstance(training, bool):
            # tf.cond expects a boolean tensor predicate; a Python bool is
            # resolved here, at graph-construction time, instead.
            return batch_statistics() if training else population_statistics()
        return tf.cond(training, batch_statistics, population_statistics)
bn_lstm.py 文件源码
python
阅读 22
收藏 0
点赞 0
评论 0
评论列表
文章目录