nnlib.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:revnet-public 作者: renmengye 项目源码 文件源码
def batch_norm(x,
               is_training,
               gamma=None,
               beta=None,
               axes=(0, 1, 2),
               eps=1e-10,
               name="bn_out",
               decay=0.99,
               dtype=tf.float32):
  """Applies batch normalization.

    Computes the mean and variance of `x` over `axes` and normalizes as:
    x_ = gamma * (x - mean) / sqrt(var + eps) + beta

    Two non-trainable variables, "ema_mean" (initialized to 0) and "ema_var"
    (initialized to 1), are obtained with `tf.get_variable` in the current
    variable scope and hold exponential moving averages of the statistics.
    In training mode the batch statistics are used and the ops that update
    the moving averages are returned; the caller is responsible for running
    them. In inference mode the moving averages are used instead.

    Args:
      x: Input tensor. The moving-average variables have shape
        `[x.shape[-1]]`, so `axes` should reduce every dimension except the
        last (the default `(0, 1, 2)` suits a rank-4 input, presumably
        NHWC -- confirm against callers).
      is_training: Python bool. True normalizes with batch statistics and
        returns the moving-average update ops; False normalizes with the
        moving averages. It is tested with a Python `if`, so it must not be
        a tensor.
      gamma: Scale tensor, or None for no scaling.
      beta: Offset tensor, or None for no offset.
      axes: Axes over which to compute the batch mean and variance.
      eps: Small constant added to the variance before the square root.
      name: Name of the normalization output op.
      decay: Moving-average decay; each update computes
        `ema = decay * ema + (1 - decay) * batch_stat`.
      dtype: dtype of the moving-average variables.

    Returns:
      A tuple `(normed, update_ops)`:
        normed: The batch-normalized tensor.
        update_ops: In training mode, the list `[ema_mean_op, ema_var_op]`
          of assign ops that update the moving averages; None otherwise.
  """
  # Channel count for the moving-average variables; may be unknown statically.
  n_out = x.get_shape()[-1]
  try:
    shape = [int(n_out)]
  except (TypeError, ValueError):
    # int() of an unknown dimension fails (e.g. int(None) -> TypeError).
    # NOTE(review): tf.get_variable with shape=None and a non-tensor
    # initializer likely raises, so this fallback may not work -- confirm.
    shape = None
  emean = tf.get_variable(
      "ema_mean",
      shape=shape,
      trainable=False,
      dtype=dtype,
      initializer=tf.constant_initializer(
          0.0, dtype=dtype))
  evar = tf.get_variable(
      "ema_var",
      shape=shape,
      trainable=False,
      dtype=dtype,
      initializer=tf.constant_initializer(
          1.0, dtype=dtype))
  if is_training:
    mean, var = tf.nn.moments(x, axes, name="moments")
    # ema -= (ema - stat) * (1 - decay)  <=>  ema = decay*ema + (1-decay)*stat
    ema_mean_op = tf.assign_sub(emean, (emean - mean) * (1 - decay))
    ema_var_op = tf.assign_sub(evar, (evar - var) * (1 - decay))
    normed = tf.nn.batch_normalization(
        x, mean, var, beta, gamma, eps, name=name)
    return normed, [ema_mean_op, ema_var_op]
  else:
    normed = tf.nn.batch_normalization(
        x, emean, evar, beta, gamma, eps, name=name)
    return normed, None
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号