def normalization(inputs, epsilon=1e-3, has_shift=True, has_scale=True,
                  activation_fn=None, scope='normalization'):
    """Normalize ``inputs`` using mean/variance computed from ``inputs`` itself.

    Statistics come from ``tf.nn.moments`` reduced over every axis except the
    last, so each position along the last axis is normalized independently.
    Optional ``shift`` (zeros-initialized) and ``scale`` (ones-initialized)
    variables shaped like ``inputs.get_shape()[-1:]`` are applied through
    ``tf.nn.batch_normalization``.

    NOTE(review): no running averages are tracked here, so the statistics of
    whatever tensor is passed in are always used (training and inference
    alike) — confirm this is intended by callers.

    Args:
        inputs: Tensor to normalize; its static rank must be known because the
            reduction axes are derived from ``get_shape().ndims``.
        epsilon: Passed as ``variance_epsilon`` to avoid division by zero.
        has_shift: When true, create a ``shift`` variable and add it.
        has_scale: When true, create a ``scale`` variable and multiply by it.
        activation_fn: Optional callable applied to the normalized tensor.
        scope: Name of the ``tf.variable_scope`` the variables live in.

    Returns:
        The normalized tensor, passed through ``activation_fn`` if one is given.
    """
    with tf.variable_scope(scope):
        static_shape = inputs.get_shape()
        channel_shape = static_shape[-1:]
        # Reduce over all axes except the trailing one.
        reduce_axes = list(range(static_shape.ndims - 1))
        mean, variance = tf.nn.moments(inputs, reduce_axes)

        def _make_param(name, init):
            # One per-channel variable matching the input dtype.
            return tf.get_variable(name,
                                   shape=channel_shape,
                                   dtype=inputs.dtype,
                                   initializer=init)

        # Creation order (shift before scale) matches the original.
        offset = _make_param('shift', tf.zeros_initializer) if has_shift else None
        gain = _make_param('scale', tf.ones_initializer) if has_scale else None

        normalized = tf.nn.batch_normalization(
            inputs, mean, variance, offset, gain, epsilon)
        if activation_fn is not None:
            normalized = activation_fn(normalized)
        return normalized
评论列表
文章目录