def layer_norm_all(h,
                   batch_size,
                   base,
                   num_units,
                   scope='layer_norm',
                   reuse=False,
                   gamma_start=1.0,
                   epsilon=1e-3,
                   use_bias=True):
    """Layer-normalise `base` equally sized groups of `h` in one pass.

    `h` is viewed as `[batch_size, base, num_units]` and every
    `num_units`-wide group is normalised independently to zero mean and
    unit variance (e.g. the i, g, j, o pre-activations of an LSTM when
    `base` is 4). The result is flattened back to
    `[batch_size, base * num_units]`, scaled by a learned gain and,
    optionally, shifted by a learned bias.

    Args:
        h: Tensor holding `batch_size * base * num_units` elements; it is
            reshaped to `[batch_size, base, num_units]`.
        batch_size: Size of the leading dimension used for the reshapes.
        base: Number of groups normalised in parallel.
        num_units: Width of each group.
        scope: Variable scope under which `ln_gamma` / `ln_beta` live.
        reuse: If true, switch the scope to variable-reuse mode before
            the variables are fetched.
        gamma_start: Constant initial value for the gain `ln_gamma`.
        epsilon: Value added to the variance before taking the reciprocal
            square root.
        use_bias: If true, create `ln_beta` (initialised to 0.0) and add
            it to the output.

    Returns:
        Tensor of shape `[batch_size, base * num_units]`.
    """
    # Group view: one row of `num_units` values per (example, group) pair,
    # so a single reduction over axis 2 normalises all groups at once.
    h_reshape = tf.reshape(h, [batch_size, base, num_units])
    mean = tf.reduce_mean(h_reshape, [2], keep_dims=True)
    var = tf.reduce_mean(tf.square(h_reshape - mean), [2], keep_dims=True)
    epsilon = tf.constant(epsilon)
    rstd = tf.rsqrt(var + epsilon)
    h_reshape = (h_reshape - mean) * rstd
    # Flatten back to the caller's layout.
    h = tf.reshape(h_reshape, [batch_size, base * num_units])

    # The gain/bias must match the flattened width. This used to be
    # hard-coded as 4 * num_units, which only worked when base == 4; for
    # base == 4 the variable shapes are unchanged.
    param_shape = [base * num_units]
    with tf.variable_scope(scope):
        if reuse:
            tf.get_variable_scope().reuse_variables()
        gamma = tf.get_variable(
            'ln_gamma', param_shape,
            initializer=tf.constant_initializer(gamma_start))
        if use_bias:
            beta = tf.get_variable(
                'ln_beta', param_shape,
                initializer=tf.constant_initializer(0.0))
            return gamma * h + beta
    return gamma * h
评论列表
文章目录