metric_ops.py source code

python
Project: lsdc  Author: febert
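# Module-level imports the excerpt omits. These paths are an assumption based on
# the TensorFlow 1.x contrib layout and may differ in the actual project file.
from tensorflow.contrib.framework.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops


def _remove_squeezable_dimensions(predictions, labels, weights):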
  """Squeeze last dim if needed.

  Squeezes `predictions` and `labels` if their rank differs by 1.
  Squeezes `weights` if its rank is 1 more than the new rank of `predictions`.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    labels: Label values, a `Tensor` whose dimensions match `predictions`.
    weights: Optional `weights` tensor. It will be squeezed if its rank is 1
      more than the new rank of `predictions`.

  Returns:
    Tuple of `predictions`, `labels` and `weights`, possibly with the last
    dimension squeezed.
  """
  predictions, labels = tensor_util.remove_squeezable_dimensions(
      predictions, labels)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  if weights is not None:
    predictions_shape = predictions.get_shape()
    predictions_rank = predictions_shape.ndims
    weights_shape = weights.get_shape()
    weights_rank = weights_shape.ndims

    if (predictions_rank is not None) and (weights_rank is not None):
      # Use static rank.
      if weights_rank - predictions_rank == 1:
        weights = array_ops.squeeze(weights, [-1])
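    elif (weights_rank is None) or (
        # Guard against rank-0 weights, whose empty `dims` would raise IndexError.
        weights_rank > 0 and weights_shape.dims[-1].is_compatible_with(1)):
      # Use dynamic rank: the squeeze decision is deferred to graph execution.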
      weights = control_flow_ops.cond(
          math_ops.equal(array_ops.rank(weights),
                         math_ops.add(array_ops.rank(predictions), 1)),
          lambda: array_ops.squeeze(weights, [-1]),
          lambda: weights)
  return predictions, labels, weights
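
The static-rank path above can be sketched without TensorFlow by tracking shapes only. This is a minimal illustration, not the library's implementation: `squeeze_last` and `squeezed_shapes` are hypothetical helpers, shapes are plain tuples with every dimension known, and the dynamic `cond` branch is not modelled.

```python
def squeeze_last(shape):
    """Drop a trailing dimension of size 1, like squeeze(x, [-1]) on a static shape."""
    if not shape or shape[-1] != 1:
        raise ValueError("cannot squeeze last dimension of shape %r" % (shape,))
    return shape[:-1]


def squeezed_shapes(predictions, labels, weights=None):
    """Shape-only model of the static-rank branch (hypothetical helper)."""
    # Step 1: align predictions and labels when their ranks differ by exactly 1.
    rank_diff = len(predictions) - len(labels)
    if rank_diff == 1:
        predictions = squeeze_last(predictions)
    elif rank_diff == -1:
        labels = squeeze_last(labels)

    # Step 2: squeeze weights only if its rank is one more than the
    # (possibly reduced) rank of predictions.
    if weights is not None and len(weights) - len(predictions) == 1:
        weights = squeeze_last(weights)
    return predictions, labels, weights


# Predictions carry a trailing singleton dimension; labels do not.
print(squeezed_shapes((8, 1), (8,), (8, 1)))
# Ranks already match; only the weights have an extra trailing dimension.
print(squeezed_shapes((8, 3), (8, 3), (8, 3, 1)))
```

Weights whose rank is equal to or lower than that of `predictions` pass through unchanged, which leaves ordinary broadcasting to whatever consumes the result.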