top_retrieval.py 文件源码

python
阅读 40 收藏 0 点赞 0 评论 0

项目:BuboQA 作者: castorini 项目源码 文件源码
def predict(dataset_iter=test_iter, dataset=test, data_name="test"):
    """Score the relation-prediction model on one split and dump its top-k candidates.

    For every example, the ``args.hits`` highest-scoring relations are written to
    ``os.path.join(results_path, "<data_name>.txt")``, one candidate per line, as
    ``"<line id> %%%% <relation> %%%% <label> %%%% <score>"``. ``label`` is 1 when
    the candidate equals the example's gold ``relation`` and 0 otherwise. Line ids
    are read from ``os.path.join(args.data_dir, "lineids_<data_name>.txt")``.

    Afterwards it prints:
      * top-1 precision: argmax prediction == gold relation, over ``len(dataset)``;
      * the number of candidate lines labelled 1 and that count as a percentage of
        ``len(dataset)`` (the "retrieval rate"). This equals the number of examples
        whose gold relation is in the top-k, presumably because ``index2tag`` maps
        distinct indices to distinct relations. TODO confirm.

    Args:
        dataset_iter: batch iterator over the split; must provide ``init_epoch()``.
        dataset: the split being evaluated; only ``len(dataset)`` is used here.
        data_name: split name used in messages and in the input/output file names.

    Raises:
        SystemExit: if ``args.dataset`` is not ``'RelationPrediction'``.
    """
    import sys  # sys.exit is always available, unlike the site-provided exit()

    print("Dataset: {}".format(data_name))
    if args.dataset != 'RelationPrediction':
        # Fail before opening the results file with 'w', so a misconfigured run
        # does not truncate output from an earlier run.
        print("Wrong Dataset")
        sys.exit()

    model.eval()
    dataset_iter.init_epoch()

    n_correct = 0
    n_retrieved = 0

    # Read line ids first: a missing ids file must not clobber earlier results.
    with open(os.path.join(args.data_dir, "lineids_{}.txt".format(data_name))) as fid:
        sent_id = [line.strip() for line in fid]

    fname = "{}.txt".format(data_name)
    with open(os.path.join(results_path, fname), 'w') as results_file:
        for data_batch_idx, data_batch in enumerate(dataset_iter):
            scores = model(data_batch)

            # Top-1 hits. int() accepts both a Python number (older torch)
            # and a 0-dim tensor (newer torch), so n_correct stays a plain int.
            top1 = torch.max(scores, 1)[1].view(data_batch.relation.size()).data
            n_correct += int((top1 == data_batch.relation.data).sum())

            # Top-k candidates per row, highest score first.
            top_k_scores, top_k_indices = torch.topk(scores, k=args.hits, dim=1, sorted=True)
            top_k_scores_array = top_k_scores.cpu().data.numpy()
            top_k_relations_array = index2tag[top_k_indices.cpu().data.numpy()]

            for i, (relations_row, scores_row) in enumerate(zip(top_k_relations_array, top_k_scores_array)):
                # NOTE(review): maps the in-batch position back to the dataset
                # position. This assumes the iterator neither shuffles nor sorts,
                # and that every batch except possibly the last holds exactly
                # args.batch_size examples. Confirm against the iterator setup.
                index = (data_batch_idx * args.batch_size) + i
                example = data_batch.dataset.examples[index]
                for rel, score in zip(relations_row, scores_row):
                    if rel == example.relation:
                        label = 1
                        n_retrieved += 1
                    else:
                        label = 0
                    results_file.write(
                        "{} %%%% {} %%%% {} %%%% {}\n".format(sent_id[index], rel, label, score))

    total = len(dataset)
    if total == 0:
        # Avoid ZeroDivisionError when the split is empty.
        print("{}: empty dataset, nothing to score".format(data_name))
        return

    P = 1. * n_correct / total
    print("{} Precision: {:10.6f}%".format(data_name, 100. * P))
    print("no. retrieved: {} out of {}".format(n_retrieved, total))
    retrieval_rate = 100. * n_retrieved / total
    print("{} Retrieval Rate {:10.6f}".format(data_name, retrieval_rate))

# run the model on the dev set and write the output to a file
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号