test.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:pytorch-PersonReID 作者: huaijin-chen 项目源码 文件源码
def best_test(model, _loader, model_path=None, is_cuda=True, thre_step=0.05):
    """Find the distance threshold that best separates positive and negative pairs.

    Runs ``model`` over every ``(data_a, data_p, data_n, target)`` batch from
    ``_loader``, collects the anchor-positive and anchor-negative pairwise
    distances, then sweeps candidate thresholds from the smaller to the larger
    of the two mean distances in steps of ``thre_step``. A pair is counted as
    correct when an anchor-positive distance is ``<= threshold`` or an
    anchor-negative distance is ``> threshold``.

    Args:
        model: network mapping an input batch to embeddings. Must provide
            ``eval()``, and ``load_full_weights`` when ``model_path`` is given.
        _loader: iterable yielding ``(data_a, data_p, data_n, target)`` tensor
            tuples; ``target`` is ignored.
        model_path: optional weights file loaded into ``model`` before evaluation.
        is_cuda: when True, move the model and every batch to the GPU.
        thre_step: spacing between candidate thresholds (default 0.05).

    Returns:
        ``(best_thre, accuracy, mean_d_a_p, mean_d_a_n)`` as Python numbers.
        ``accuracy`` is the fraction of the ``2 * num_triplets`` pairs classified
        correctly at ``best_thre``. ``best_thre`` stays ``0.0`` if no candidate
        threshold classifies any pair correctly.

    Raises:
        ValueError: if ``thre_step`` is not positive or ``_loader`` yields no
            triplets.
    """
    if thre_step <= 0:
        raise ValueError('best_test: thre_step must be positive')
    if model_path is not None:
        model.load_full_weights(model_path)
        print('loaded model file: {}'.format(model_path))
    if is_cuda:
        model = model.cuda()
    model.eval()

    # Collect per-batch distances and concatenate once at the end; calling
    # torch.cat inside the loop re-copies the growing tensor on every batch.
    pos_chunks = []
    neg_chunks = []
    total_num = 0
    # torch.no_grad() replaces the removed Variable(..., volatile=True) idiom,
    # so no autograd graph is recorded during evaluation.
    with torch.no_grad():
        for data_a, data_p, data_n, _target in _loader:
            if is_cuda:
                data_a = data_a.cuda()
                data_p = data_p.cuda()
                data_n = data_n.cuda()
            out_a = model(data_a)
            out_p = model(data_p)
            out_n = model(data_n)
            pos_chunks.append(F.pairwise_distance(out_a, out_p))
            neg_chunks.append(F.pairwise_distance(out_a, out_n))
            # Each triplet yields one positive pair and one negative pair.
            total_num += 2 * data_a.size(0)

    if total_num == 0:
        raise ValueError('best_test: _loader yielded no triplets')

    d_a_p = torch.cat(pos_chunks, 0)
    d_a_n = torch.cat(neg_chunks, 0)
    # .item() is the supported way to read a 0-dim tensor; indexing it with
    # .data[0] raises IndexError on modern PyTorch.
    mean_d_a_p = d_a_p.mean().item()
    mean_d_a_n = d_a_n.mean().item()
    start = min(mean_d_a_n, mean_d_a_p)
    end = max(mean_d_a_n, mean_d_a_p)

    best_thre = 0.0
    best_num = 0
    # end + thre_step is the exclusive upper bound of arange, so `end` itself
    # is (up to float rounding) among the candidates.
    for val in torch.arange(start, end + thre_step, thre_step).tolist():
        num = ((d_a_p <= val).sum() + (d_a_n > val).sum()).item()
        # Strict '>' keeps the smallest threshold among equally good ones.
        if num > best_num:
            best_num = num
            best_thre = val
    return best_thre, float(best_num) / total_num, mean_d_a_p, mean_d_a_n
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号