cnnT3.py source code

python

Project: future-price-predictor  Author: htfy96
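import torch
from torch.autograd import Variable

# cuda_ava, classes and target_onehot_to_classnum_tensor are not shown in this
# excerpt; they must be defined or imported elsewhere in the module.
def evaluate(model, testloader, args, use_cuda=False):
    """Evaluate model on the first 20 batches of testloader.

    Prints overall and per-class accuracy and returns the overall accuracy.
    """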
    correct = 0
    total = 0
    class_correct = [0.0] * 2
    class_total = [0.0] * 2
    for batch_idx, data in enumerate(testloader, 0):
        # Only evaluate the first 20 batches.
        if batch_idx == 20:
            break
        inputs, targets = data
        inputs = inputs.unsqueeze(1)
        targets = target_onehot_to_classnum_tensor(targets)
        if use_cuda and cuda_ava:
            inputs = Variable(inputs.float().cuda())
            targets = targets.cuda()
        else:
            inputs = Variable(inputs.float())
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
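        # Per-sample correctness mask (1-D, one entry per sample in the batch).
        c = (predicted == targets)
        correct += int(c.sum())
        # Iterate over the actual batch size: the last batch may be smaller
        # than args.batch_size, which would raise an IndexError.
        for j in range(targets.size(0)):
            target = int(targets[j])
            class_correct[target] += int(c[j])
            class_total[target] += 1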

    print("Accuracy of the network is: %.5f %%" % (correct / total * 100))

    for i in range(2):
        if class_total[i] == 0:
            print("Accuracy of %1s : %1s %% (%1d / %1d)" % (classes[i], "NaN", class_correct[i], class_total[i]))
        else:
            print("Accuracy of %1s : %.5f %% (%1d / %1d)" % (classes[i], class_correct[i] / class_total[i] * 100, class_correct[i], class_total[i]))

    return correct / total