def ensemble_test(test_data, models):
    """Evaluate an ensemble where each model handles one document-length bucket.

    Every sample is routed to a bucket according to the length of its
    document, using the module-level ``d_bucket`` sequence of
    ``(d_min, d_max)`` pairs. Bucket ``i`` is then evaluated with
    ``models[i]`` and the combined accuracy is logged at CRITICAL level.

    Args:
        test_data: Four parallel sequences ``(documents, questions, answers,
            candidates)``; they are zipped together into per-sample tuples.
        models: Sequence of objects exposing ``test(data, batch_size=...)``.
            The first element of each returned pair is summed and divided by
            the number of evaluated samples (presumably a count of correct
            predictions -- confirm against the model class).

    Returns:
        None. Results are reported through ``logging`` only.
    """
    data = [[] for _ in d_bucket]
    unassigned = 0

    for test_document, test_question, test_answer, test_candidate in zip(*test_data):
        sample = (test_document, test_question, test_answer, test_candidate)
        doc_len = len(test_document)

        # Documents at or below the first bucket's lower bound go to the first bucket.
        if doc_len <= d_bucket[0][0]:
            data[0].append(sample)
            continue
        # Documents at or above the last bucket's upper bound go to the last model.
        # NOTE(review): indexes by len(models), which assumes one model per
        # bucket (len(models) == len(d_bucket)) -- confirm with the caller.
        if doc_len >= d_bucket[-1][-1]:
            data[len(models) - 1].append(sample)
            continue

        # Inclusive bounds with first-match-wins: a length exactly equal to a
        # bucket boundary used to match no bucket (strict "<" on both sides)
        # and the sample was silently dropped from the evaluation.
        for bucket_id, (d_min, d_max) in enumerate(d_bucket):
            if d_min <= doc_len <= d_max:
                data[bucket_id].append(sample)
                break
        else:
            # Only reachable when d_bucket has gaps between consecutive ranges.
            unassigned += 1

    if unassigned:
        logging.warning(
            "{} samples matched no length bucket and were skipped.".format(unassigned)
        )

    acc, num = 0, 0
    for i, model in enumerate(models):
        bucket = data[i]
        if not bucket:
            # zip(*[]) has no columns at all, so there is nothing to hand to the model.
            continue
        num += len(bucket)
        logging.info("Start testing.\nTesting in {} samples.".format(len(bucket)))
        acc_i, _ = model.test(zip(*bucket), batch_size=1)
        acc += acc_i

    if num == 0:
        # Avoid ZeroDivisionError when there was nothing to evaluate.
        logging.warning("Ensemble test skipped: no test samples were evaluated.")
        return
    logging.critical("Ensemble test done.\nAccuracy is {}".format(acc / num))
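# Web-page residue from extraction, not part of the program: "评论列表" (comment list)
# Web-page residue from extraction, not part of the program: "文章目录" (table of contents)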