distribution_util.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码
def get_logits_and_probs(logits=None,
                         probs=None,
                         multidimensional=False,
                         validate_args=False,
                         name="get_logits_and_probs"):
  """Returns a `(logits, probs)` pair, deriving whichever one was not given.

  Exactly one of `logits` and `probs` must be supplied; the other is computed
  from it.

  Args:
    logits: Numeric `Tensor` of log-odds, or `None` when `probs` is given.
    probs: Numeric `Tensor` of probabilities, or `None` when `logits` is given.
    multidimensional: `Boolean`, default `False`. When `True`, the last
      dimension of the `[N1, N2, ... k]` input is treated as holding the
      logits or probabilities of `shape[-1]` classes: `logits` are converted
      with softmax and `probs` with `log(probs)`. When `False`, `logits` are
      converted with sigmoid and `probs` with `log(probs) - log1p(-probs)`.
    validate_args: `Boolean`, default `False`. Only consulted when `probs` is
      given. When `True`, asserts that `probs` is non-negative and, in
      addition, either that `probs <= 1` (if not `multidimensional`) or that
      the last dimension of `probs` sums to one (if `multidimensional`).
    name: A name for this operation (optional).

  Returns:
    logits, probs: Tuple of `Tensor`s. In the unidimensional case, an entry of
      `probs` equal to `0` or `1` yields a logit of `-Inf` or `Inf`
      respectively. In the multidimensional case the returned logits are
      `log(probs)`.

  Raises:
    ValueError: if neither `probs` nor `logits` were passed in, or both were.
  """
  with ops.name_scope(name, values=[probs, logits]):
    probs_given = probs is not None
    logits_given = logits is not None
    if probs_given == logits_given:
      raise ValueError("Must pass probs or logits, but not both.")

    if logits_given:
      # Logits supplied: derive the probabilities and return immediately.
      logits = ops.convert_to_tensor(logits, name="logits")
      to_probs = nn.softmax if multidimensional else math_ops.sigmoid
      return logits, to_probs(logits, name="probs")

    # Probabilities supplied: optionally validate them, then derive the logits.
    probs = ops.convert_to_tensor(probs, name="probs")
    if validate_args:
      with ops.name_scope("validate_probs"):
        unity = constant_op.constant(1., probs.dtype)
        checks = [check_ops.assert_non_negative(probs)]
        if multidimensional:
          total = math_ops.reduce_sum(probs, -1)
          checks.append(
              assert_close(total, unity, message="probs does not sum to 1."))
        else:
          checks.append(check_ops.assert_less_equal(
              probs, unity, message="probs has components greater than 1."))
        # Tie the assertions to `probs` so that they run whenever it is used.
        probs = control_flow_ops.with_dependencies(checks, probs)

    with ops.name_scope("logits"):
      log_probs = math_ops.log(probs)
      if multidimensional:
        # The multidimensional case is deliberately not computed in the same
        # manner as the unidimensional one. A consistent treatment would be
        # logits = log(probs) - log(gather(probs, pivot)); instead we follow
        # the TF convention and return log(probs). As a side effect the
        # unidimensional case handles its second dimension implicitly, while
        # the multidimensional case explicitly keeps the pivot dimension.
        return log_probs, probs
      return log_probs - math_ops.log1p(-1. * probs), probs
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号