def get_logits_and_probs(logits=None,
                         probs=None,
                         multidimensional=False,
                         validate_args=False,
                         name="get_logits_and_probs"):
    """Derives the missing one of `logits` / `probs` and returns the pair.

    Exactly one of `logits` and `probs` must be supplied; the other is
    computed from it.

    Args:
      logits: Numeric `Tensor` representing log-odds.
      probs: Numeric `Tensor` representing probabilities.
      multidimensional: `Boolean`, default `False`. When `True`, the last
        dimension of the `[N1, N2, ... k]` input is treated as the logits or
        probabilities of `shape[-1]` classes: `softmax` is used to obtain
        `probs` from `logits`, and a plain `log` to obtain `logits` from
        `probs`. When `False`, `sigmoid` and `log(p) - log1p(-p)` are used
        instead.
      validate_args: `Boolean`, default `False`. Only consulted when `probs`
        is supplied. When `True`, assertions are attached to `probs` checking
        that it is non-negative and, additionally, either that
        `probs <= 1` (if not `multidimensional`) or that the last dimension
        of `probs` sums to one.
      name: Name of the scope the created ops are placed under (optional).

    Returns:
      logits, probs: Tuple of `Tensor`s. In the non-`multidimensional` case,
      an entry of `probs` equal to `0` or `1` yields `-Inf` or `Inf`
      respectively in the corresponding entry of the returned logits.

    Raises:
      ValueError: if neither `probs` nor `logits` were passed in, or both
        were.
    """
    with ops.name_scope(name, values=[probs, logits]):
        have_probs = probs is not None
        have_logits = logits is not None
        if have_probs == have_logits:
            raise ValueError("Must pass probs or logits, but not both.")

        # Logits supplied: squash them into probabilities and finish early.
        if not have_probs:
            logits = ops.convert_to_tensor(logits, name="logits")
            to_probs = nn.softmax if multidimensional else math_ops.sigmoid
            return logits, to_probs(logits, name="probs")

        # Probabilities supplied: optionally guard them, then invert.
        probs = ops.convert_to_tensor(probs, name="probs")
        if validate_args:
            with ops.name_scope("validate_probs"):
                unity = constant_op.constant(1., probs.dtype)
                checks = [check_ops.assert_non_negative(probs)]
                if multidimensional:
                    row_total = math_ops.reduce_sum(probs, -1)
                    checks.append(
                        assert_close(row_total, unity,
                                     message="probs does not sum to 1."))
                else:
                    checks.append(
                        check_ops.assert_less_equal(
                            probs, unity,
                            message="probs has components greater than 1."))
                # Tie the assertions to `probs` so they run whenever it is
                # evaluated.
                probs = control_flow_ops.with_dependencies(checks, probs)

        with ops.name_scope("logits"):
            log_probs = math_ops.log(probs)
            if multidimensional:
                # Deliberately NOT the analogue of the unidimensional formula
                # (which would be log(probs) - log(gather(probs, pivot))).
                # This follows the TF convention: the unidimensional case
                # handles the second class implicitly, whereas the
                # multidimensional case keeps the pivot dimension explicitly.
                return log_probs, probs
            return log_probs - math_ops.log1p(-1. * probs), probs
distribution_util.py 文件源码
python
阅读 17
收藏 0
点赞 0
评论 0
评论列表
文章目录