head.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码
def _logits(logits_input, logits, logits_dimension):
  """Validate logits args, and create `logits` if necessary.

  Exactly one of `logits_input` and `logits` must be provided.

  Args:
    logits_input: `Tensor` input to `logits`.
    logits: `Tensor` output.
    logits_dimension: Integer, last dimension of `logits`. This is used to
      create `logits` from `logits_input` if `logits` is `None`; otherwise, it's
      used to validate `logits`.

  Returns:
    `logits` `Tensor`.

  Raises:
    ValueError: if neither or both of `logits` and `logits_input` are supplied.
  """
  if (logits_dimension is None) or (logits_dimension < 1):
    raise ValueError("Invalid logits_dimension %s." % logits_dimension)

  # If not provided, create logits.
  if logits is None:
    if logits_input is None:
      raise ValueError("Neither logits nor logits_input supplied.")
    return layers_lib.linear(logits_input, logits_dimension, scope="logits")

  if logits_input is not None:
    raise ValueError("Both logits and logits_input supplied.")

  logits = ops.convert_to_tensor(logits, name="logits")
  logits_dims = logits.get_shape().dims
  if logits_dims is not None:
    logits_dims[-1].assert_is_compatible_with(logits_dimension)

  return logits
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号