extra.py source file

python

Project: self-supervision  Author: gustavla
# Imports this excerpt relies on (TensorFlow 1.x internal modules).
# `normalize_moments` is assumed to be defined or imported elsewhere in extra.py.
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn

def moments(x, axes, shift=None, name=None, keep_dims=False):
  """Calculate the mean and variance of `x`.

  The mean and variance are calculated by aggregating the contents of `x`
  across `axes`.  If `x` is 1-D and `axes = [0]` this is just the mean
  and variance of a vector.

  Note: for numerical stability, when `shift=None` the true mean is
  computed and used as the shift.

  When using these moments for batch normalization (see
  `tf.nn.batch_normalization`):

   * for so-called "global normalization", used with convolutional filters with
     shape `[batch, height, width, depth]`, pass `axes=[0, 1, 2]`.
   * for simple batch normalization pass `axes=[0]` (batch only).

  Args:
    x: A `Tensor`.
    axes: Array of ints.  Axes along which to compute mean and
      variance.
    shift: A `Tensor` containing the value by which to shift the data for
      numerical stability, or `None` in which case the true mean of the data is
      used as shift. A shift close to the true mean provides the most
      numerically stable results.
    name: Name used to scope the operations that compute the moments.
    keep_dims: produce moments with the same dimensionality as the input.

  Returns:
    Two `Tensor` objects: `mean` and `variance`.
  """
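  # The name scope is intentionally disabled here; `if True:` only preserves
  # the indentation of the original block.
  # with ops.name_scope(name, "moments", [x, axes, shift]):
  if True: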
    # The dynamic range of fp16 is too limited to support the collection of
    # sufficient statistics. As a workaround we simply perform the operations
    # on 32-bit floats before converting the mean and variance back to fp16
    y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
    if shift is None:
      # Compute true mean while keeping the dims for proper broadcasting.
      shift = array_ops.stop_gradient(
          math_ops.reduce_mean(y, axes, keep_dims=True))
    else:
      shift = math_ops.cast(shift, y.dtype)
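    # `name` defaults to None, so only append the suffix when a name is given
    # (the original `name+'_statistics'` raises TypeError for name=None).
    stats_name = name + '_statistics' if name is not None else None
    counts, m_ss, v_ss, shift = nn.sufficient_statistics(
        y, axes, shift=shift, keep_dims=keep_dims, name=stats_name)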
    # Reshape shift as needed.
    shift = array_ops.reshape(shift, array_ops.shape(m_ss))
    shift.set_shape(m_ss.get_shape())
    with ops.control_dependencies([counts, m_ss, v_ss]):
      mean, variance = normalize_moments(counts, m_ss, v_ss, shift, name=name)
      if x.dtype == dtypes.float16:
        return (math_ops.cast(mean, dtypes.float16),
                math_ops.cast(variance, dtypes.float16))
      else:
        return (mean, variance)
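
The arithmetic behind the function above can be illustrated without TensorFlow. The following is a minimal NumPy sketch, not the project's code: `shifted_moments` is a hypothetical helper, and it assumes that the sufficient statistics are the count, the sum of `x - shift`, and the sum of `(x - shift)**2`, from which the mean is recovered as `shift + m_ss / count` and the variance as `v_ss / count - (m_ss / count)**2`. It also mirrors the fp16 workaround by accumulating in float32.

```python
import numpy as np


def shifted_moments(x, axes, shift=None, keepdims=False):
    """Hypothetical NumPy sketch of shifted mean/variance computation."""
    x = np.asarray(x)
    # Mirror the fp16 workaround: accumulate statistics in float32.
    y = x.astype(np.float32) if x.dtype == np.float16 else x
    axes = tuple(axes)

    if shift is None:
        # Use the true mean as the shift, kept broadcastable against y.
        shift = y.mean(axis=axes, keepdims=True)
    else:
        shift = np.asarray(shift, dtype=y.dtype)

    # Sufficient statistics: element count and shifted first/second sums.
    count = int(np.prod([y.shape[a] for a in axes]))
    m_ss = (y - shift).sum(axis=axes, keepdims=keepdims)
    v_ss = ((y - shift) ** 2).sum(axis=axes, keepdims=keepdims)

    # Reshape the shift to match the reduced statistics when sizes agree;
    # otherwise (e.g. a scalar shift) rely on broadcasting.
    if shift.size == m_ss.size:
        shift = shift.reshape(m_ss.shape)

    shifted_mean = m_ss / count
    mean = shifted_mean + shift
    variance = v_ss / count - shifted_mean ** 2

    if x.dtype == np.float16:
        # Convert the results back to the input precision.
        return mean.astype(np.float16), variance.astype(np.float16)
    return mean, variance


# "Global normalization" over a [batch, height, width, depth] array:
# reduce over axes 0, 1, 2 to get one mean/variance per depth channel.
rng = np.random.RandomState(0)
images = rng.randn(4, 3, 3, 2) + 1000.0
global_mean, global_var = shifted_moments(images, axes=[0, 1, 2])

# Simple batch normalization: reduce over the batch axis only.
batch_mean, batch_var = shifted_moments(images, axes=[0])
```

With `axes=[0, 1, 2]` the results have one entry per depth channel, while `axes=[0]` keeps the spatial dimensions. Subtracting a shift close to the mean before squaring keeps both sums small, which is why the data offset of 1000 in this example does not hurt the variance estimate.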