def batchnorm(x, name, phase, updates, gamma=0.96, epsilon=1e-5):
    """Apply batch normalization along axis 0 of ``x`` with a learned affine map.

    Four variables are created under ``name``: non-trainable running
    statistics (``/mean``, ``/var``) and trainable ``/scaling`` and
    ``/translation``, each of shape ``[1, k]`` where ``k`` is
    ``x.get_shape()[1]``.

    Args:
        x: Input tensor; statistics are reduced over axis 0 and the size of
            axis 1 determines the parameter shapes.
        name: Prefix used for the names of the created variables.
        phase: Passed to ``switch`` to choose between the batch-statistics
            output and the running-statistics output (presumably truthy
            selects the batch/training branch -- confirm against ``switch``).
        updates: List that is extended in place with the two assign ops that
            refresh the running mean and running variance. The caller is
            responsible for running them.
        gamma: Decay factor of the exponential moving averages.
        epsilon: Small constant added to the variance before the square root
            so that a feature that is constant across the batch does not
            produce a division by zero.

    Returns:
        The normalized tensor multiplied by ``scaling`` plus ``translation``.
    """
    k = x.get_shape()[1]
    runningmean = tf.get_variable(
        name + "/mean", shape=[1, k],
        initializer=tf.constant_initializer(0.0), trainable=False)
    runningvar = tf.get_variable(
        name + "/var", shape=[1, k],
        initializer=tf.constant_initializer(1e-4), trainable=False)

    # Inference branch: normalize with the accumulated running statistics.
    testy = (x - runningmean) / tf.sqrt(runningvar + epsilon)

    # Training branch: normalize with the statistics of the current batch.
    mean_ = mean(x, axis=0, keepdims=True)
    centered = x - mean_
    # Variance is the mean of the squared *centered* values; the raw second
    # moment mean(x**2) would overestimate it whenever the batch mean != 0.
    var_ = mean(tf.square(centered), axis=0, keepdims=True)
    std = tf.sqrt(var_ + epsilon)
    trainy = centered / std

    # Exponential moving averages of the batch statistics.
    updates.extend([
        tf.assign(runningmean, runningmean * gamma + mean_ * (1 - gamma)),
        tf.assign(runningvar, runningvar * gamma + var_ * (1 - gamma)),
    ])

    y = switch(phase, trainy, testy)

    # Learned per-feature affine transform applied after normalization.
    scaling = tf.get_variable(
        name + "/scaling", shape=[1, k],
        initializer=tf.constant_initializer(1.0), trainable=True)
    translation = tf.get_variable(
        name + "/translation", shape=[1, k],
        initializer=tf.constant_initializer(0.0), trainable=True)
    return y * scaling + translation
# ================================================================
# Mathematical utils
# ================================================================
# (Web page residue, not part of the source: "Comment list" / "Article table of contents")