train.py 文件源码

python
阅读 68 收藏 0 点赞 0 评论 0

项目:ResNeXt.pytorch 作者: prlz77 项目源码 文件源码
def test():
        net.eval()
        loss_avg = 0.0
        correct = 0
        for batch_idx, (data, target) in enumerate(test_loader):
            data, target = torch.autograd.Variable(data.cuda()), torch.autograd.Variable(target.cuda())

            # forward
            output = net(data)
            loss = F.cross_entropy(output, target)

            # accuracy
            pred = output.data.max(1)[1]
            correct += pred.eq(target.data).sum()

            # test loss average
            loss_avg += loss.data[0]

        state['test_loss'] = loss_avg / len(test_loader)
        state['test_accuracy'] = correct / len(test_loader.dataset)


    # Main loop
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号