def _logits(logits_input, logits, logits_dimension):
"""Validate logits args, and create `logits` if necessary.
Exactly one of `logits_input` and `logits` must be provided.
Args:
logits_input: `Tensor` input to `logits`.
logits: `Tensor` output.
logits_dimension: Integer, last dimension of `logits`. This is used to
create `logits` from `logits_input` if `logits` is `None`; otherwise, it's
used to validate `logits`.
Returns:
`logits` `Tensor`.
Raises:
ValueError: if neither or both of `logits` and `logits_input` are supplied.
"""
if (logits_dimension is None) or (logits_dimension < 1):
raise ValueError("Invalid logits_dimension %s." % logits_dimension)
# If not provided, create logits.
if logits is None:
if logits_input is None:
raise ValueError("Neither logits nor logits_input supplied.")
return layers_lib.linear(logits_input, logits_dimension, scope="logits")
if logits_input is not None:
raise ValueError("Both logits and logits_input supplied.")
logits = ops.convert_to_tensor(logits, name="logits")
logits_dims = logits.get_shape().dims
if logits_dims is not None:
logits_dims[-1].assert_is_compatible_with(logits_dimension)
return logits
head.py 文件源码
python
阅读 22
收藏 0
点赞 0
评论 0
评论列表
文章目录