test.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:nmt-seq2seq 作者: ZeweiChu 项目源码 文件源码
def main(args):
    """Evaluate a saved translation model on the test set and print metrics.

    Loads the pickled vocabulary and the serialized model named by ``args``,
    encodes and batches the test file, runs ``translate`` over the batches,
    then prints the per-word loss, the accuracy and the word count returned
    by ``eval``.

    Args:
        args: Namespace-like object. Reads ``vocab_file``, ``model_file``,
            ``use_cuda``, ``test_file`` and ``batch_size``; writes
            ``en_total_words``, ``cn_total_words`` and ``num_test``.

    Raises:
        SystemExit: With code -1 when the vocabulary or model file is
            missing, or when the evaluation reports zero words.
    """
    if not os.path.isfile(args.vocab_file):
        print("vocab file does not exist!")
        raise SystemExit(-1)

    # NOTE(review): pickle can execute arbitrary code while loading; only
    # point vocab_file at files from a trusted source.
    with open(args.vocab_file, "rb") as vocab_fh:
        en_dict, cn_dict, en_total_words, cn_total_words = pickle.load(vocab_fh)

    args.en_total_words = en_total_words
    args.cn_total_words = cn_total_words
    # Reverse mappings (value -> key) of both vocabularies, handed to translate().
    inv_en_dict = {v: k for k, v in en_dict.items()}
    inv_cn_dict = {v: k for k, v in cn_dict.items()}

    if not os.path.isfile(args.model_file):
        print("model file does not exist!")
        raise SystemExit(-1)

    # When CUDA is not requested, remap storages to the CPU so a checkpoint
    # saved from a GPU run can still be loaded on a CPU-only machine.
    model = torch.load(
        args.model_file,
        map_location=None if args.use_cuda else "cpu",
    )
    if args.use_cuda:
        model = model.cuda()

    crit = utils.LanguageModelCriterion()

    test_en, test_cn = utils.load_data(args.test_file)
    args.num_test = len(test_en)
    test_en, test_cn = utils.encode(test_en, test_cn, en_dict, cn_dict)
    test_data = utils.gen_examples(test_en, test_cn, args.batch_size)

    translate(model, test_data, en_dict, inv_en_dict, cn_dict, inv_cn_dict)

    # NOTE(review): `eval` is presumably this module's own evaluation
    # function shadowing the builtin -- confirm against the rest of the file.
    correct_count, loss, num_words = eval(model, test_data, args, crit)
    if num_words == 0:
        # Avoid dividing by zero when the test set yields no scored words.
        print("test set contains no words; cannot compute loss or accuracy")
        raise SystemExit(-1)

    loss = loss / num_words
    acc = correct_count / num_words
    print("test loss %s" % (loss) )
    print("test accuracy %f" % (acc))
    print("test total number of words %f" % (num_words))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号