test.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:pytorch-PersonReID 作者: huaijin-chen 项目源码 文件源码
def test(model, test_loader, epoch, margin, threshlod, is_cuda=True, log_interval=1000):
    """Evaluate a triplet-embedding model and return its threshold accuracy.

    For every batch the anchor, positive and negative inputs are embedded
    with ``model``. A positive pair counts as correct when its pairwise
    distance is below ``threshlod``; a negative pair counts as correct when
    its distance is above ``threshlod``. Accuracy is the number of correct
    pairs divided by the number of pairs seen (two per triplet).

    Args:
        model: Callable network; ``model.eval()`` is invoked before the loop.
        test_loader: Sized iterable yielding
            ``(data_a, data_p, data_n, target)`` batches. ``target`` is
            unpacked but not used by the evaluation.
        epoch: Epoch number, used only in the progress message.
        margin: Margin forwarded to ``F.triplet_margin_loss``.
        threshlod: Distance threshold separating positive from negative
            pairs (the misspelling is kept because it is part of the
            public signature).
        is_cuda: When true, the three input batches are moved to the GPU
            with ``.cuda()``.
        log_interval: Print progress every ``log_interval`` batches; a
            falsy value (``0`` / ``None``) disables progress output.

    Returns:
        float: Fraction of correctly classified pairs, or ``0.0`` when
        ``test_loader`` yields no samples.
    """
    # Local import: the file's import block is not visible from this block.
    import torch

    model.eval()
    test_loss = AverageMeter()
    num_correct = 0.0
    total_num = 0
    batch_num = len(test_loader)

    # ``torch.no_grad()`` replaces ``Variable(..., volatile=True)``, which no
    # longer disables autograd bookkeeping on current PyTorch.
    with torch.no_grad():
        for batch_idx, (data_a, data_p, data_n, _target) in enumerate(test_loader):
            if is_cuda:
                data_a = data_a.cuda()
                data_p = data_p.cuda()
                data_n = data_n.cuda()

            out_a = model(data_a)
            out_p = model(data_p)
            out_n = model(data_n)

            loss = F.triplet_margin_loss(out_a, out_p, out_n, margin=margin)

            dist_pos = F.pairwise_distance(out_a, out_p)
            dist_neg = F.pairwise_distance(out_a, out_n)

            # ``.item()`` extracts the Python scalar; indexing a 0-dim
            # tensor with ``.data[0]`` raises IndexError on current PyTorch.
            hits_pos = (dist_pos < threshlod).float().sum()
            hits_neg = (dist_neg > threshlod).float().sum()
            num_correct += (hits_pos + hits_neg).item()
            # Each triplet contributes one positive and one negative pair.
            total_num += data_a.size(0) * 2

            test_loss.update(loss.item())
            if log_interval and (batch_idx + 1) % log_interval == 0:
                accuracy_tmp = num_correct / total_num if total_num else 0.0
                print('Test- Epoch {:04d}\tbatch:{:06d}/{:06d}\tAccuracy:{:.04f}\tloss:{:06f}'
                      .format(epoch, batch_idx + 1, batch_num, accuracy_tmp, test_loss.avg))
                # The reported loss is the average over the current logging
                # window only, so the meter is cleared after each report.
                test_loss.reset()

    # Guard against an empty loader instead of raising ZeroDivisionError.
    if total_num == 0:
        return 0.0
    return num_correct / total_num
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号