import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope, tf.nn.moments)


def apply_batch_norm(input_tensor, config, i):
    """Normalise input_tensor and apply a learnable scale and shift.

    `config` is unused here but kept for interface consistency; `i` is the
    time-step index, so the variables are created once and reused afterwards.
    """
    with tf.variable_scope("batch_norm") as scope:
        if i != 0:
            # Do not create extra variables for each time step
            scope.reuse_variables()
        # Mean and variance normalisation, simply crunched over all axes
        axes = list(range(len(input_tensor.get_shape())))
        mean, variance = tf.nn.moments(input_tensor, axes=axes)
        # Small constant avoids division by zero for near-constant inputs
        stdev = tf.sqrt(variance + 0.001)
        # Standardise to zero mean and unit variance
        bn = (input_tensor - mean) / stdev
        # Learnable extra rescaling: scale starts at 0.5, shift at 0.0
        # (stddev=0.0 makes the random initialisers effectively constant)
        bn *= tf.get_variable("a_noreg", initializer=tf.random_normal([1], mean=0.5, stddev=0.0))
        bn += tf.get_variable("b_noreg", initializer=tf.random_normal([1], mean=0.0, stddev=0.0))
        return bn
Source file: residual_lstm_model_MNIST_dataset.py
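The function above is called once per time step of the recurrent model: the first call (i == 0) creates the two learnable parameters, and every later call reuses them through scope.reuse_variables(). A minimal usage sketch follows, assuming a plain Python loop over time steps; the placeholder shapes, the number of steps, and passing config=None are illustrative choices, not taken from the original file.

import tensorflow as tf  # TensorFlow 1.x API

hidden_size = 28  # hypothetical feature width, e.g. one MNIST row per step
n_steps = 3       # hypothetical number of time steps

# One placeholder per time step, standing in for the RNN's hidden states
steps = [tf.placeholder(tf.float32, [None, hidden_size]) for _ in range(n_steps)]

normalised = []
for i, x in enumerate(steps):
    # i == 0 creates "a_noreg" and "b_noreg"; later iterations reuse them
    normalised.append(apply_batch_norm(x, config=None, i=i))

Note that the moments are taken over every axis, including the batch axis, so each call standardises the whole tensor against a single scalar mean and variance. This is a simplification relative to standard batch normalisation, which keeps per-feature statistics.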