main.py source code

python

Project: attention-sum-reader  Author: cairoHy
import logging


def ensemble_test(test_data, models):
    # d_bucket is a sequence of (d_min, d_max) document-length ranges defined
    # outside this snippet; the samples in data[i] are tested with models[i].
    data = [[] for _ in d_bucket]
    for sample in zip(*test_data):
        # sample = (test_document, test_question, test_answer, test_candidate)
        doc_len = len(sample[0])
        if doc_len <= d_bucket[0][0]:
            # At or below the smallest bound: use the first bucket.
            data[0].append(sample)
            continue
        if doc_len >= d_bucket[-1][-1]:
            # At or above the largest bound: use the last bucket.
            data[-1].append(sample)
            continue
        for bucket_id, (d_min, d_max) in enumerate(d_bucket):
            # Inclusive upper bound, so a length equal to a boundary shared
            # by two adjacent buckets is not silently dropped.
            if d_min < doc_len <= d_max:
                data[bucket_id].append(sample)
                break

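    acc, num = 0, 0
    for i in range(len(models)):
        num += len(data[i])
        logging.info("Start testing.\nTesting in {} samples.".format(len(data[i])))
        # acc_i is treated as a count of correct answers, since the sum is
        # divided by the total number of samples below.
        acc_i, _ = models[i].test(zip(*data[i]), batch_size=1)
        acc += acc_i
    # Guard against an empty test set to avoid a ZeroDivisionError.
    accuracy = acc / num if num else 0.0
    logging.critical("Ensemble test done.\nAccuracy is {}".format(accuracy))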
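The bucketing rule used by ensemble_test can be isolated into a small helper, which makes the boundary behaviour easier to check. The following is a minimal sketch, not the project's code: `assign_bucket` and the ranges in `EXAMPLE_BUCKETS` are hypothetical names and values chosen for illustration, and the real `d_bucket` is defined outside the snippet. The sketch treats the upper bound of each range as inclusive, which is an assumption. With a strict comparison on both sides, a length equal to a boundary shared by two adjacent buckets would match neither of them.

```python
# Hypothetical (d_min, d_max) ranges, for illustration only.
EXAMPLE_BUCKETS = [(10, 20), (20, 30), (30, 40)]


def assign_bucket(doc_len, buckets=EXAMPLE_BUCKETS):
    """Return the index of the bucket for a document of length doc_len."""
    if doc_len <= buckets[0][0]:
        # At or below the smallest bound: first bucket.
        return 0
    if doc_len >= buckets[-1][-1]:
        # At or above the largest bound: last bucket.
        return len(buckets) - 1
    for bucket_id, (d_min, d_max) in enumerate(buckets):
        # Assumed inclusive upper bound; the first matching range wins.
        if d_min < doc_len <= d_max:
            return bucket_id
    # Only reachable if the ranges leave a gap between them.
    raise ValueError("no bucket covers length {}".format(doc_len))


if __name__ == "__main__":
    for length in (5, 20, 21, 40, 99):
        print(length, assign_bucket(length))
```

Returning an index instead of appending in place keeps the rule free of side effects, so each boundary case can be checked without constructing any models.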