def confusion_matrix_op(logits, labels, num_classes):
    """Build the op computing the confusion matrix of predictions vs. labels.

    The predicted class of each example is the top-1 entry of ``logits``.
    ``num_classes`` must be supplied so that the matrix size does not depend
    on which labels/predictions happen to appear in a given batch.

    Args:
        logits: a [batch_size, 1, 1, num_classes] tensor (e.g. the output of a
            fully convolutional classifier) or a [batch_size, num_classes]
            tensor.
        labels: a [batch_size] tensor of ground-truth class indices.
        num_classes: number of classes; passed through to
            ``tf.confusion_matrix`` to fix the output size.

    Returns:
        confusion_matrix_op: the confusion matrix tf op returned by
        ``tf.confusion_matrix`` (rows are true labels, columns are
        predicted labels).
    """
    with tf.variable_scope('confusion_matrix'):
        logits_shape = logits.shape
        # Fully convolutional classifiers emit [batch, 1, 1, classes]; drop the
        # singleton spatial dims. ``ndims`` is used instead of ``len()``
        # because ``len()`` raises ValueError on a shape of unknown rank,
        # whereas ``ndims`` simply returns None.
        if (logits_shape.ndims == 4
                and logits_shape[1:3].as_list() == [1, 1]):
            top_k_logits = tf.squeeze(logits, [1, 2])
        else:
            top_k_logits = logits
        # Extract the predicted label (top-1).
        _, top_predicted_label = tf.nn.top_k(top_k_logits, k=1, sorted=False)
        # (batch_size, k) with k = 1 -> (batch_size,)
        top_predicted_label = tf.squeeze(top_predicted_label, axis=1)
        return tf.confusion_matrix(
            labels, top_predicted_label, num_classes=num_classes)
评论列表
文章目录