def _check_labels_and_scores(boolean_labels, scores, check_shape):
    """Convert labels/scores to tensors, validate them, and return both.

    Args:
      boolean_labels: Value convertible to a tensor; the resulting tensor
        must have dtype bool.
      scores: Value convertible to a tensor.
      check_shape: If truthy, add graph assertions that both tensors have
        rank 1 and make the returned tensors depend on those assertions.

    Returns:
      A tuple `(boolean_labels, scores)` of tensors. When `check_shape` is
      truthy these are identity ops of the converted inputs, gated on the
      rank assertions; otherwise they are the converted inputs themselves.

    Raises:
      ValueError: If `boolean_labels` does not have dtype bool.
    """
    with ops.name_scope('_check_labels_and_scores',
                        values=[boolean_labels, scores]):
        boolean_labels = ops.convert_to_tensor(boolean_labels,
                                               name='boolean_labels')
        scores = ops.convert_to_tensor(scores, name='scores')

        if boolean_labels.dtype != dtypes.bool:
            # Interpolate the dtype into the message; passing it as a second
            # positional argument to ValueError leaves the '%s' unformatted.
            raise ValueError(
                'Argument boolean_labels should have dtype bool. Found: %s'
                % (boolean_labels.dtype,))

        if not check_shape:
            return boolean_labels, scores

        labels_rank_1 = control_flow_ops.Assert(
            math_ops.equal(1, array_ops.rank(boolean_labels)),
            ['Argument boolean_labels should have rank 1. Found: ',
             boolean_labels.name, array_ops.shape(boolean_labels)])

        scores_rank_1 = control_flow_ops.Assert(
            math_ops.equal(1, array_ops.rank(scores)),
            ['Argument scores should have rank 1. Found: ', scores.name,
             array_ops.shape(scores)])

        with ops.control_dependencies([labels_rank_1, scores_rank_1]):
            # Control dependencies attach only to ops constructed inside this
            # context. Returning the pre-existing tensors would silently drop
            # the assertions, so wrap each tensor in a new identity op.
            return (array_ops.identity(boolean_labels),
                    array_ops.identity(scores))
评论列表
文章目录