ensamble.py source code

python

Project: PyTorchText  Author: chenyuntc
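def ensamble(file1, file2, label_path, test_data_path, result_csv=None):
    # In the original file, label_path and test_data_path default to
    # module-level variables not shown in this snippet; here they are required.
    import csv
    import json
    import time

    import numpy as np
    import torch as t

    if result_csv is None:
        # Name the output after the current time: yymmdd_HHMMSS.csv
        result_csv = time.strftime('%y%m%d_%H%M%S.csv')

    # Two saved score matrices of shape (num_samples, num_labels)
    a = t.load(file1)
    b = t.load(file2)
    # Weighted ensemble: the first model's scores count 9x the second's
    r = 9.0 * a + b
    # Indices of the 5 highest-scoring labels per sample (topk returns values, indices)
    result = r.topk(5, 1)[1]

    # index2qid is a pickled dict inside the .npz, so allow_pickle is required
    index2qid = np.load(test_data_path, allow_pickle=True)['index2qid'].item()
    with open(label_path) as f:
        label2qid = json.load(f)['id2label']

    rows = []  # range() is immutable in Python 3, so build a list instead
    for ii, item in enumerate(result):
        # JSON keys are strings; tolist() yields ints so str() gives "3", not "tensor(3)"
        rows.append([index2qid[ii]] + [label2qid[str(idx)] for idx in item.tolist()])

    with open(result_csv, 'w', newline='') as f:
        csv.writer(f).writerows(rows)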
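To make the fusion and ranking steps easier to see, here is a minimal NumPy sketch of the same logic on toy data. It assumes the inputs are (num_samples, num_labels) score matrices, that `index2qid` is a dict stored in an `.npz` file as a 0-d object array (which is what the `.item()` call implies), and that `id2label` is keyed by stringified label indices. The helper names `ensemble_top_k` and `build_rows`, the file name `test.npz`, and all question and label ids are hypothetical; `np.argsort` stands in for `torch.topk`, and the two may order tied scores differently.

```python
import csv
import io
import os
import tempfile

import numpy as np


def ensemble_top_k(scores_a, scores_b, weight_a=9.0, k=5):
    """Hypothetical helper: fuse two (num_samples, num_labels) score matrices
    and return the indices of the k best labels per row, best first.
    NumPy stand-in for `r = 9.0*a + b; r.topk(5, 1)[1]` in the snippet."""
    fused = weight_a * scores_a + scores_b
    # Sort each row by descending score; the stable sort breaks ties by the
    # lower label index (torch.topk makes no such promise).
    return np.argsort(-fused, axis=1, kind="stable")[:, :k]


def build_rows(top_idx, index2qid, id2label):
    """Hypothetical helper: one CSV row per sample, question id first,
    then its k predicted labels. id2label uses string keys, as JSON does."""
    return [
        [index2qid[i]] + [id2label[str(j)] for j in row.tolist()]
        for i, row in enumerate(top_idx)
    ]


# Hypothetical toy data: 2 questions, 4 candidate labels.
a = np.array([[0.1, 0.9, 0.3, 0.0],
              [0.5, 0.2, 0.4, 0.8]])
b = np.array([[0.9, 0.0, 0.0, 0.5],
              [0.0, 3.0, 0.0, 0.0]])
index2qid = {0: "q100", 1: "q200"}
id2label = {"0": "t0", "1": "t1", "2": "t2", "3": "t3"}

# Round-trip index2qid through an .npz the way the snippet reads it:
# savez stores the dict as a 0-d object array, and loading it back
# needs allow_pickle=True followed by .item().
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "test.npz")
    np.savez(path, index2qid=index2qid)
    with np.load(path, allow_pickle=True) as data:
        loaded = data["index2qid"].item()

top = ensemble_top_k(a, b, k=2)
rows = build_rows(top, loaded, id2label)

# Write the rows as CSV into memory instead of a file.
buf = io.StringIO()
csv.writer(buf).writerows(rows)
print(buf.getvalue())
```

In the second row, model `b` alone would rank `t1` first, but the 9:1 weighting lets model `a`'s preference for `t3` win, so that row becomes `q200,t3,t1`.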