import tensorflow as tf


def layernorm(x, axis, name):
    '''
    Layer normalization (Ba et al., 2016).
    Z-normalizes each sample over all nodes of the layer, then applies a
    learned per-channel scale and offset.
    Input:
        `x`: tensor in channels-first/NCHW format (or fully connected)
        `axis`: list of axes to normalize over; the first entry must be
            the channel axis
        `name`: variable-name prefix; must be assigned
    Example:
        # axis = [1, 2, 3]
        # x = tf.random_normal([64, 3, 10, 10])
        # name = 'D_layernorm'
    Return:
        (x - u) / s * scale + offset
    Source:
        https://github.com/igul222/improved_wgan_training/blob/master/tflib/ops/layernorm.py
    '''
    # Per-sample mean and variance over the normalization axes.
    mean, var = tf.nn.moments(x, axis, keep_dims=True)
    # Number of channels; the learned parameters are per-channel.
    n_neurons = x.get_shape().as_list()[axis[0]]
    # Shape [C, 1, ..., 1] so the parameters broadcast over the spatial axes.
    param_shape = [n_neurons] + [1 for _ in range(len(axis) - 1)]
    offset = tf.get_variable(
        name + '.offset',
        shape=param_shape,
        initializer=tf.zeros_initializer()
    )
    scale = tf.get_variable(
        name + '.scale',
        shape=param_shape,
        initializer=tf.ones_initializer()
    )
    # Computes (x - mean) / sqrt(var + 1e-5) * scale + offset.
    return tf.nn.batch_normalization(x, mean, var, offset, scale, 1e-5)
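

# A minimal usage sketch under TF1 graph semantics, mirroring the shapes in
# the docstring example; the session setup below is an assumption for
# illustration, not part of the original source.
if __name__ == '__main__':
    x = tf.random_normal([64, 3, 10, 10])                  # NCHW batch
    y = layernorm(x, axis=[1, 2, 3], name='D_layernorm')   # normalize over C, H, W
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        out = sess.run(y)
        print(out.shape)                                   # (64, 3, 10, 10)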