def _remove_squeezable_dimensions(predictions, labels, weights):
  """Squeeze last dim if needed.

  Squeezes `predictions` and `labels` if their rank differs by 1.
  Squeezes `weights` if its rank is 1 more than the new rank of `predictions`.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    labels: Label values, a `Tensor` whose dimensions match `predictions`.
    weights: Optional `weights` tensor. It will be squeezed if its rank is 1
      more than the new rank of `predictions`. Scalar (rank-0) weights and
      `None` are returned unchanged.

  Returns:
    Tuple of `predictions`, `labels` and `weights`, possibly with the last
    dimension squeezed.

  Raises:
    ValueError: If the static shapes of `predictions` and `labels` are not
      compatible after squeezing (raised by
      `TensorShape.assert_is_compatible_with`).
  """
  predictions, labels = tensor_util.remove_squeezable_dimensions(
      predictions, labels)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
  if weights is None:
    return predictions, labels, weights

  weights_shape = weights.get_shape()
  weights_rank = weights_shape.ndims
  if weights_rank == 0:
    # A scalar can never have rank "1 more than predictions", so there is
    # nothing to squeeze. Returning here also keeps the dynamic-rank check
    # below from indexing `weights_shape.dims[-1]` on an empty dims list,
    # which raised IndexError when the rank of `predictions` was unknown.
    return predictions, labels, weights

  predictions_rank = predictions.get_shape().ndims
  if (predictions_rank is not None) and (weights_rank is not None):
    # Use static rank.
    if weights_rank - predictions_rank == 1:
      weights = array_ops.squeeze(weights, [-1])
  elif (weights_rank is None) or (
      weights_shape.dims[-1].is_compatible_with(1)):
    # Use dynamic rank. When `weights` has a known rank whose last dimension
    # is statically incompatible with 1, that dimension cannot be squeezed,
    # so no graph ops are added in that case.
    weights = control_flow_ops.cond(
        math_ops.equal(array_ops.rank(weights),
                       math_ops.add(array_ops.rank(predictions), 1)),
        lambda: array_ops.squeeze(weights, [-1]),
        lambda: weights)
  return predictions, labels, weights
评论列表
文章目录