rnn.py source file

python
Project: magenta  Author: tensorflow
# Uses the TensorFlow 1.x API (tf.variable_scope, tf.get_variable).
import tensorflow as tf

def layer_norm_all(h,
                   batch_size,
                   base,
                   num_units,
                   scope='layer_norm',
                   reuse=False,
                   gamma_start=1.0,
                   epsilon=1e-3,
                   use_bias=True):
  """Layer Norm (faster version, but not using defun)."""
  # Performs layer norm on multiple bases at once
  # (i.e., the i, g, j, o gates of an LSTM).
  # Reshapes h to [batch_size, base, num_units] so that every base is
  # normalized over its own num_units values in parallel.
  h_reshape = tf.reshape(h, [batch_size, base, num_units])
  mean = tf.reduce_mean(h_reshape, [2], keep_dims=True)
  var = tf.reduce_mean(tf.square(h_reshape - mean), [2], keep_dims=True)
  epsilon = tf.constant(epsilon)
  rstd = tf.rsqrt(var + epsilon)
  h_reshape = (h_reshape - mean) * rstd
  # reshape back to original
  h = tf.reshape(h_reshape, [batch_size, base * num_units])
  with tf.variable_scope(scope):
    if reuse:
      tf.get_variable_scope().reuse_variables()
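    # One scale (and optional shift) per normalized value. The shape is
    # base * num_units so it matches h; this equals 4 * num_units for the
    # LSTM case (base=4).
    gamma = tf.get_variable(
        'ln_gamma', [base * num_units],
        initializer=tf.constant_initializer(gamma_start))
    if use_bias:
      beta = tf.get_variable(
          'ln_beta', [base * num_units],
          initializer=tf.constant_initializer(0.0))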
  if use_bias:
    return gamma * h + beta
  return gamma * h
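
The normalization step above can be checked outside TensorFlow. The following is a minimal NumPy sketch, not part of the magenta source: the name `layer_norm_all_np` and the optional `gamma`/`beta` arguments are assumptions made for illustration. It reproduces the reshape, per-group mean and variance, and rescaling, then confirms that each group of `num_units` values ends up with roughly zero mean and unit variance.

```python
import numpy as np


def layer_norm_all_np(h, batch_size, base, num_units,
                      gamma=None, beta=None, epsilon=1e-3):
    """NumPy sketch of the per-base layer norm (hypothetical helper)."""
    # Split the last axis into `base` groups of `num_units` values each.
    h_reshape = h.reshape(batch_size, base, num_units)
    mean = h_reshape.mean(axis=2, keepdims=True)
    var = np.square(h_reshape - mean).mean(axis=2, keepdims=True)
    # Same as multiplying by rsqrt(var + epsilon) in the TensorFlow code.
    normed = (h_reshape - mean) / np.sqrt(var + epsilon)
    # Flatten back to [batch_size, base * num_units].
    out = normed.reshape(batch_size, base * num_units)
    if gamma is not None:
        out = gamma * out
    if beta is not None:
        out = out + beta
    return out


rng = np.random.default_rng(0)
batch_size, base, num_units = 2, 4, 8
h = rng.normal(loc=3.0, scale=5.0, size=(batch_size, base * num_units))

out = layer_norm_all_np(h, batch_size, base, num_units)
groups = out.reshape(batch_size, base, num_units)
print(groups.mean(axis=2))  # close to 0 for every (example, base) pair
print(groups.var(axis=2))   # slightly below 1 because of epsilon
```

Because `epsilon` is added to the variance before taking the square root, the normalized variance is `var / (var + epsilon)`, which is just under 1 rather than exactly 1.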