metrics.py source file

python

Project: dynamic-training-bench  Author: galeone
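import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope, tf.confusion_matrix)


def confusion_matrix_op(logits, labels, num_classes):
    """Creates the operation that builds the confusion matrix between the
    predictions and the labels. The number of classes is required to build
    a matrix of the correct size.
    Args:
        logits: a [batch_size, 1, 1, num_classes] tensor or
                a [batch_size, num_classes] tensor
        labels: a [batch_size] tensor of integer class ids in
                [0, num_classes)
        num_classes: the number of classes, used as the matrix size
    Returns:
        confusion_matrix_op: a [num_classes, num_classes] int32 tensor whose
            rows are the true labels and whose columns are the predictions
    """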
    with tf.variable_scope('confusion_matrix'):
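        # fully convolutional classifiers emit [batch_size, 1, 1, num_classes]:
        # squeeze the two spatial dimensions to get [batch_size, num_classes]
        logits_shape = logits.shape
        if len(logits_shape) == 4 and logits_shape[1:3].as_list() == [1, 1]: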
            top_k_logits = tf.squeeze(logits, [1, 2])
        else:
            top_k_logits = logits

        # Extract the predicted label (top-1)
        _, top_predicted_label = tf.nn.top_k(top_k_logits, k=1, sorted=False)
        # (batch_size, k) -> k = 1 -> (batch_size)
        top_predicted_label = tf.squeeze(top_predicted_label, axis=1)

        return tf.confusion_matrix(
            labels, top_predicted_label, num_classes=num_classes)
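To make the semantics of the op concrete without a TensorFlow session, here is a minimal pure-Python sketch of the same computation. It assumes logits are given as nested lists, and the names `confusion_matrix_py`, `_squeeze_spatial` and `_top1` are hypothetical helpers, not part of dynamic-training-bench. Like `tf.confusion_matrix`, rows index the true label and columns the predicted (top-1) label.

```python
from typing import List, Sequence


def _squeeze_spatial(row):
    """Hypothetical helper: turn a [1, 1, num_classes] row into [num_classes],
    mirroring tf.squeeze(logits, [1, 2]); a [num_classes] row is unchanged."""
    if (len(row) == 1 and isinstance(row[0], list)
            and len(row[0]) == 1 and isinstance(row[0][0], list)):
        return row[0][0]
    return row


def _top1(row: Sequence[float]) -> int:
    """Hypothetical helper: index of the largest logit. max() keeps the first
    maximum, so ties go to the lower index (the tie rule tf.nn.top_k documents)."""
    return max(range(len(row)), key=lambda i: row[i])


def confusion_matrix_py(logits: list, labels: Sequence[int],
                        num_classes: int) -> List[List[int]]:
    """Hypothetical pure-Python reference for confusion_matrix_op.
    Rows are true labels, columns are top-1 predictions."""
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same batch size")
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for row, label in zip(logits, labels):
        scores = _squeeze_spatial(row)
        if len(scores) != num_classes:
            raise ValueError("each logits row must have num_classes entries")
        if not 0 <= label < num_classes:
            raise ValueError("label out of range [0, num_classes)")
        matrix[label][_top1(scores)] += 1
    return matrix


logits = [[0.1, 2.0, -1.0], [3.0, 0.0, 0.5], [0.2, 0.1, 0.9], [1.0, 1.5, 0.0]]
labels = [1, 0, 1, 0]
print(confusion_matrix_py(logits, labels, num_classes=3))
# [[1, 1, 0], [0, 1, 1], [0, 0, 0]]
```

The diagonal counts correct predictions, and summing row i gives the number of examples whose true label is i. Passing `num_classes` explicitly keeps the matrix square and correctly sized even when a batch contains no example of the highest class.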