train.py 文件源码

python
阅读 48 收藏 0 点赞 0 评论 0

项目:DeepLearning_PlantDiseases 作者: MarkoArsenovic 项目源码 文件源码
def evaluate_stats(net, testloader):
    stats = {}
    correct = 0
    total = 0

    before = time.time()
    for i, data in enumerate(testloader, 0):
        images, labels = data

        if use_gpu:
            images, labels = (images.cuda()), (labels.cuda(async=True))

        outputs = net(Variable(images))
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
    accuracy = correct / total
    stats['accuracy'] = accuracy
    stats['eval_time'] = time.time() - before

    print('Accuracy on test images: %f' % accuracy)
    return stats
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号