def _norm(input, is_train, reuse=True, norm=None):
    """Apply the selected normalization to `input`.

    Args:
        input: Feature tensor to normalize. Instance norm reduces over axes
            [1, 2], so a 4-D NHWC layout is assumed — TODO confirm with callers.
        is_train: Python bool or bool tensor; forwarded to batch norm as
            `is_training`. Unused for instance norm / pass-through.
        reuse: Whether to reuse variables in the created variable scope.
        norm: One of 'instance', 'batch', or None (return `input` unchanged).

    Returns:
        The normalized tensor, or `input` itself when `norm` is None.

    Raises:
        ValueError: If `norm` is not one of the accepted values.
    """
    # `assert` is stripped under `python -O`, so validate explicitly.
    if norm not in ('instance', 'batch', None):
        raise ValueError("norm must be 'instance', 'batch' or None, got %r" % (norm,))

    if norm == 'instance':
        with tf.variable_scope('instance_norm', reuse=reuse):
            eps = 1e-5
            # Per-sample, per-channel statistics over the spatial axes (H, W).
            mean, sigma = tf.nn.moments(input, [1, 2], keep_dims=True)
            # eps belongs INSIDE the sqrt: the gradient of sqrt at 0 is
            # infinite, so `tf.sqrt(sigma) + eps` still yields NaN gradients
            # for constant feature maps; `tf.sqrt(sigma + eps)` cannot.
            out = (input - mean) / tf.sqrt(sigma + eps)
            # NOTE: a learnable affine (scale/shift) is deliberately omitted;
            # add trainable 'scale'/'shift' variables here if it is ever needed.
    elif norm == 'batch':
        with tf.variable_scope('batch_norm', reuse=reuse):
            out = tf.contrib.layers.batch_norm(
                input,
                decay=0.99, center=True, scale=True,
                is_training=is_train,
                # Apply moving-average updates in place instead of deferring
                # them to a collection the training loop must run explicitly.
                updates_collections=None)
    else:
        # norm is None: pass-through.
        out = input
    return out
# (page-scrape artifact, not code — "评论列表" = "comment list", "文章目录" = "table of contents")