def batch_norm_layer_in_time(x, max_length, step, is_training, epsilon=1e-3, decay=0.99, scope="layer"):
    """Per-time-step batch normalization for recurrent networks (TF1 graph mode).

    Keeps a separate scale/offset and population mean/variance for each of the
    `max_length` time steps. All per-step parameters are stored as one flat
    variable of length `size * max_length` and the slice for the current step
    is selected with `tf.gather` / updated with `tf.scatter_update`.

    Args:
        x: 2-D [batch, values], 3-D [batch, width, values] or
            4-D [batch, width, height, values] tensor; normalization runs over
            every axis except the last (feature) axis.
        max_length: maximum number of time steps (Python int).
        step: current time-step index (int or scalar int tensor).
        is_training: Python bool chosen at graph-construction time; selects
            batch statistics (and EMA updates) vs. frozen population statistics.
        epsilon: numerical-stability constant added to the variance.
        decay: exponential-moving-average decay for the population statistics.
        scope: suffix for the variable scope; variables live under 'bn_' + scope.

    Returns:
        A tensor with the same shape as `x`, batch-normalized for this step.
    """
    with tf.variable_scope('bn_' + scope):
        dim_x = len(x.get_shape().as_list())
        size = x.get_shape().as_list()[dim_x - 1]
        # Indices of this time step's slice inside the flat per-step vectors.
        step_idcs = tf.range(step * size, (step + 1) * size)
        scale_var = tf.get_variable('scale', [size * max_length],
                                    initializer=tf.constant_initializer(0.1))
        scale = tf.gather(scale_var, step_idcs)
        # Bug fix: offset (beta) must start at zero. The original passed no
        # initializer, so tf.get_variable fell back to glorot_uniform — a
        # random per-feature shift before any training.
        offset_var = tf.get_variable('offset', [size * max_length],
                                     initializer=tf.zeros_initializer())
        offset = tf.gather(offset_var, step_idcs)
        pop_mean_var = tf.get_variable('pop_mean', [size * max_length],
                                       initializer=tf.zeros_initializer(),
                                       trainable=False)
        pop_mean = tf.gather(pop_mean_var, step_idcs)
        pop_var_var = tf.get_variable('pop_var', [size * max_length],
                                      initializer=tf.ones_initializer(),
                                      trainable=False)
        pop_var = tf.gather(pop_var_var, step_idcs)
        # Batch moments over every axis except the feature (last) axis.
        batch_mean, batch_var = tf.nn.moments(x, list(range(dim_x - 1)))
        # EMA updates that touch only this step's slice of the population stats.
        train_mean_op = tf.scatter_update(
            pop_mean_var, step_idcs,
            pop_mean * decay + batch_mean * (1 - decay))
        train_var_op = tf.scatter_update(
            pop_var_var, step_idcs,
            pop_var * decay + batch_var * (1 - decay))

        def batch_statistics():
            # Force the EMA updates to run whenever the training-mode
            # output is evaluated.
            with tf.control_dependencies([train_mean_op, train_var_op]):
                return tf.nn.batch_normalization(
                    x, batch_mean, batch_var, offset, scale, epsilon)

        def population_statistics():
            return tf.nn.batch_normalization(
                x, pop_mean, pop_var, offset, scale, epsilon)

        # is_training is a plain Python bool, so the branch is resolved at
        # graph-build time (no tf.cond needed).
        return batch_statistics() if is_training else population_statistics()
# NOTE(review): removed blog-page scrape residue ("评论列表" / "文章目录",
# i.e. "comment list" / "article table of contents") — as bare identifiers
# they would raise NameError when this module is imported.