def batch_norm(x, name="batch_norm"):
    """Normalize ``x`` over axes 1 and 2, then apply a learned per-channel affine map.

    Despite the name, the mean and variance come from ``tf.nn.moments`` over
    ``axes=[1, 2]`` only. Each (index-0, index-3) slice is therefore
    normalized with its own statistics, which is instance-normalization-style
    behavior rather than statistics pooled across axis 0. No moving averages
    are created or updated.

    Args:
        x: Input tensor. It presumably is 4-D NHWC (the channel count is read
            from ``get_shape()[3]`` and axes 1, 2 are reduced). TODO confirm
            against callers.
        name: Name of the ``tf.variable_scope`` that holds the ``scale`` and
            ``center`` variables.

    Returns:
        ``(x - mean) * rsqrt(variance + 1e-6) * scale + center``, broadcast
        to the shape of ``x`` (given the assumed 4-D input).
    """
    epsilon = 1e-6
    with tf.variable_scope(name):
        channels = x.get_shape()[3]
        # Learned per-channel gain, drawn from N(1.0, 0.02) at initialization.
        gain_init = tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32)
        scale = tf.get_variable("scale", [channels], initializer=gain_init)
        # Learned per-channel offset, initialized to zero.
        offset_init = tf.constant_initializer(0.0, dtype=tf.float32)
        center = tf.get_variable("center", [channels], initializer=offset_init)
        # tf.nn.moments returns (mean, variance); keep_dims=True keeps the
        # reduced axes as size 1 so both broadcast back against x.
        mean, variance = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
        inv_std = tf.rsqrt(variance + epsilon)
        return (x - mean) * inv_std * scale + center
评论列表
文章目录