def ensamble(file1, file2, label_path=label_path, test_data_path=test_data_path,
             result_csv=None, weight=9.0, top_k=5):
    """Blend two saved score tensors and write the top-k labels per row to CSV.

    The tensors stored in ``file1`` and ``file2`` are combined as
    ``weight * a + b``. For every row, the column indices of the ``top_k``
    largest blended scores are mapped to labels and written out.

    Args:
        file1: Path of a tensor saved with ``torch.save``; weighted by ``weight``.
        file2: Path of a tensor saved with ``torch.save``; weight 1.
            NOTE(review): assumed to have the same shape as ``file1``'s tensor
            (or be broadcast-compatible) -- confirm with the producer.
        label_path: JSON file containing an ``'id2label'`` mapping whose keys
            are column indices rendered as strings.
        test_data_path: ``np.load``-able archive with an ``'index2qid'`` entry;
            ``.item()`` unwraps it into an object indexable by row number.
        result_csv: Output path. Defaults to a timestamp such as
            ``'240131_235959.csv'`` in the current directory.
        weight: Multiplier applied to ``file1``'s scores (original: 9.0).
        top_k: Number of labels emitted per row (original: 5).

    Each CSV row is ``[index2qid[row], label_1, ..., label_top_k]``. The labels
    are ordered by descending blended score.
    """
    import csv
    import json
    import time

    import numpy as np
    import torch as t

    if result_csv is None:
        result_csv = time.strftime('%y%m%d_%H%M%S.csv')

    # NOTE(review): torch.load unpickles its input; only load trusted files.
    a = t.load(file1)
    b = t.load(file2)
    scores = weight * a + b
    # topk along dim 1 -> per-row column indices of the largest scores,
    # sorted descending (topk's default sorted=True).
    result = scores.topk(top_k, 1)[1]

    # The entry is unwrapped with .item(), i.e. it is stored as an object
    # array; numpy >= 1.16.3 refuses to load those unless allow_pickle=True.
    index2qid = np.load(test_data_path, allow_pickle=True)['index2qid'].item()
    with open(label_path) as f:
        id2label = json.load(f)['id2label']

    # .tolist() gives plain Python ints, so str(idx) is '3' rather than
    # 'tensor(3)' (which is what iterating a modern torch tensor would give).
    # A list is built instead of assigning into range(), which is immutable.
    rows = [
        [index2qid[ii]] + [id2label[str(idx)] for idx in item]
        for ii, item in enumerate(result.tolist())
    ]

    # newline='' is required by the csv module to avoid blank lines on Windows.
    with open(result_csv, 'w', newline='') as f:
        csv.writer(f).writerows(rows)
评论列表
文章目录