low_shot.py source file

python

Project: low-shot-shrink-hallucinate Author: facebookresearch
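import numpy as np
from torch.autograd import Variable

# perelement_accuracy is called below but is not shown in this excerpt.
def eval_loop(data_loader, model, base_classes, novel_classes):
    # Switch to evaluation mode (affects layers such as dropout and batch norm).
    model = model.eval()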
    top1 = None
    top5 = None
    all_labels = None
    for i, (x,y) in enumerate(data_loader):
        x = Variable(x.cuda())
        scores = model(x)
        top1_this, top5_this = perelement_accuracy(scores.data, y)
        top1 = top1_this if top1 is None else np.concatenate((top1, top1_this))
        top5 = top5_this if top5 is None else np.concatenate((top5, top5_this))
        all_labels = y.numpy() if all_labels is None else np.concatenate((all_labels, y.numpy()))

    # Boolean masks selecting the examples whose label is a novel or a base class.
    is_novel = np.in1d(all_labels, novel_classes)
    is_base = np.in1d(all_labels, base_classes)
    is_either = is_novel | is_base
    top1_novel = np.mean(top1[is_novel])
    top1_base = np.mean(top1[is_base])
    top1_all = np.mean(top1[is_either])
    top5_novel = np.mean(top5[is_novel])
    top5_base = np.mean(top5[is_base])
    top5_all = np.mean(top5[is_either])
    # Order: novel top-1, novel top-5, base top-1, base top-5, all top-1, all top-5.
    return np.array([top1_novel, top5_novel, top1_base, top5_base, top1_all, top5_all])
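
The function above relies on per-example correctness flags that are then averaged separately over novel and base classes. The following is a minimal, dependency-free sketch of that idea using plain Python lists instead of tensors. The names `topk_correct` and `group_accuracy` are hypothetical and are not part of the project; the toy data uses top-2 in place of top-5 because it has only three classes, and returning NaN for an empty group is an assumption.

```python
def topk_correct(scores, labels, k):
    """Return one 1.0/0.0 flag per example: is the true label among the k highest scores?"""
    flags = []
    for row, label in zip(scores, labels):
        # Class indices sorted by score, highest first, truncated to the top k.
        ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        flags.append(1.0 if label in ranked else 0.0)
    return flags


def group_accuracy(flags, labels, classes):
    """Mean of the flags over the examples whose label is in `classes`."""
    wanted = set(classes)
    picked = [f for f, y in zip(flags, labels) if y in wanted]
    # Assumption: report NaN for an empty group instead of raising an error.
    return sum(picked) / len(picked) if picked else float("nan")


# Toy data (hypothetical): 4 examples, 3 classes; class 2 plays the novel class.
scores = [[0.9, 0.05, 0.05], [0.2, 0.7, 0.1], [0.5, 0.2, 0.3], [0.1, 0.2, 0.7]]
labels = [0, 1, 2, 2]

top1 = topk_correct(scores, labels, k=1)
top2 = topk_correct(scores, labels, k=2)
novel_top1 = group_accuracy(top1, labels, classes=[2])
base_top1 = group_accuracy(top1, labels, classes=[0, 1])
```

Keeping one flag per example, rather than a running average, is what allows the same results to be re-aggregated over any subset of classes after the loop has finished.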