metrics.py source file

python

Project: dynamic-training-bench  Author: galeone
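import tensorflow as tf  # TF 1.x API: tf.variable_scope, tf.nn.in_top_k(predictions, targets, k)


def accuracy_op(logits, labels):
    """Define the top-1 accuracy op between predictions (logits) and labels.
    Args:
        logits: a [batch_size, 1, 1, num_classes] tensor or
                a [batch_size, num_classes] tensor
        labels: a [batch_size] int32 or int64 tensor of class indices
    Returns:
        accuracy: a scalar float32 tensor holding the fraction of correct predictions
    """

    with tf.variable_scope('accuracy'):
        # Handle fully convolutional classifiers: drop the 1x1 spatial dims
        logits_shape = logits.shape.as_list()
        if len(logits_shape) == 4 and logits_shape[1:3] == [1, 1]:
            top_k_logits = tf.squeeze(logits, [1, 2])
        else:
            top_k_logits = logits
        # True where the label's logit is the highest one (ties count as hits)
        top_k_op = tf.nn.in_top_k(top_k_logits, labels, 1)
        # Mean of the boolean hits cast to float gives the accuracy in [0, 1]
        accuracy = tf.reduce_mean(tf.cast(top_k_op, tf.float32))

    return accuracy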
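To make the semantics of accuracy_op concrete without a TensorFlow install, here is a minimal pure-Python sketch that computes the same quantity on nested lists. It is not part of dynamic-training-bench, and the helper names (flatten_logits, in_top_k_ref, accuracy_ref) are hypothetical. It assumes the tie rule documented for tf.nn.in_top_k (a label counts as a hit when fewer than k classes score strictly higher than it) and does not model TensorFlow's handling of non-finite scores or out-of-range labels.

```python
from typing import List, Sequence


def flatten_logits(logits: Sequence) -> List[List[float]]:
    """Hypothetical helper: [B, 1, 1, C] nested lists -> [B, C]; [B, C] passes through."""
    rows = []
    for sample in logits:
        # A fully convolutional output wraps each class vector as [[scores]].
        if (isinstance(sample, (list, tuple)) and len(sample) == 1
                and isinstance(sample[0], (list, tuple)) and len(sample[0]) == 1
                and isinstance(sample[0][0], (list, tuple))):
            rows.append([float(x) for x in sample[0][0]])
        else:
            rows.append([float(x) for x in sample])
    return rows


def in_top_k_ref(row: Sequence[float], target: int, k: int = 1) -> bool:
    """Hit if fewer than k classes score strictly higher than the target (ties count as hits)."""
    target_score = row[target]
    higher = sum(1 for score in row if score > target_score)
    return higher < k


def accuracy_ref(logits: Sequence, labels: Sequence[int], k: int = 1) -> float:
    """Fraction of samples whose label is in the top k, like reduce_mean(cast(in_top_k))."""
    rows = flatten_logits(logits)
    if not rows or len(rows) != len(labels):
        raise ValueError("logits and labels must have the same, non-zero batch size")
    hits = [in_top_k_ref(row, label, k) for row, label in zip(rows, labels)]
    # Booleans sum as 0/1, so this is the mean hit rate.
    return sum(hits) / len(hits)
```

For example, with logits [[0.1, 0.9, 0.0], [0.8, 0.1, 0.1]] and labels [1, 2], only the first row's label has the highest score, so accuracy_ref returns 0.5. Wrapping each row as [[row]] (the [batch_size, 1, 1, num_classes] layout) gives the same result, mirroring the squeeze branch in accuracy_op.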