def accuracy_op(predictions, targets, num_classes=5):
    """Compute classification accuracy of ``predictions`` against ``targets``.

    Both inputs are converted with ``np.asarray``. Plain lists and other
    array-likes are therefore accepted in addition to NumPy arrays. The
    computation is done on concrete values; it builds no TensorFlow graph ops.

    Args:
        predictions: array-like. Either 2-D with one row per sample, where the
            predicted class is taken as the argmax over axis 1, or 1-D. A 1-D
            input is first expanded with ``one_hot(predictions, m=num_classes)``
            before the same argmax is applied.
        targets: array-like. Either 2-D with one row per sample, where the true
            class is taken as the argmax over axis 1 (e.g. one-hot rows), or
            1-D class labels used as-is.
        num_classes: int, number of classes passed to ``one_hot`` when
            ``predictions`` is 1-D.

    Returns:
        The result of ``accuracy_score(true_labels, predicted_labels)``.
        This is presumably ``sklearn.metrics.accuracy_score``, i.e. the
        fraction of samples whose labels match; confirm against the file's
        imports.
    """
    # Normalise first so `.ndim` works for lists and other array-likes,
    # not only for objects that already expose it.
    targets = np.asarray(targets)
    predictions = np.asarray(predictions)

    if targets.ndim == 2:
        # One row per sample: the highest-valued column is the true label.
        targets = np.argmax(targets, axis=1)

    if predictions.ndim == 1:
        # Promote 1-D predictions to 2-D so the argmax below applies uniformly.
        # NOTE(review): presumably one_hot maps label k to a length-num_classes
        # indicator row. If so, this round trip is an identity on in-range
        # labels and could be skipped. Confirm against its definition.
        predictions = one_hot(predictions, m=num_classes)

    predicted_labels = np.argmax(predictions, axis=1)
    return accuracy_score(targets, predicted_labels)
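# NOTE(review): the two lines below are web-page residue, not code
# ("评论列表" = "comment list", "文章目录" = "table of contents").
# 评论列表
# 文章目录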