def _load_vocab(path):
    """Load and return the pickled vocabulary object stored at *path*.

    The caller unpacks it as ``(en_dict, cn_dict, en_total_words,
    cn_total_words)``.

    Raises:
        SystemExit: with code -1 (after printing a message to stderr) when
            *path* is not an existing regular file.
    """
    import sys

    if not os.path.isfile(path):
        print("vocab file does not exist!", file=sys.stderr)
        sys.exit(-1)
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only load vocab files you produced yourself.
    with open(path, "rb") as f:  # `with` closes the handle the original leaked
        return pickle.load(f)


def _load_model(path, use_cuda):
    """Load the model saved at *path* with ``torch.load``; move it to the GPU if *use_cuda*.

    Raises:
        SystemExit: with code -1 (after printing a message to stderr) when
            *path* is not an existing regular file.
    """
    import sys

    if not os.path.isfile(path):
        print("model file does not exist!", file=sys.stderr)
        sys.exit(-1)
    # When CUDA is not requested, map stored tensors onto the CPU so a
    # checkpoint saved on a GPU machine can still be loaded on a CPU-only one.
    # NOTE(review): torch.load unpickles the file (same trust caveat as
    # above). On PyTorch >= 2.6 torch.load defaults to weights_only=True, which
    # rejects whole pickled modules; pass weights_only=False there if the file
    # is trusted.
    model = torch.load(path, map_location=None if use_cuda else "cpu")
    if use_cuda:
        model = model.cuda()
    return model


def main(args):
    """Evaluate a trained translation model on the test file and print metrics.

    Loads the vocabulary and model named by *args*, encodes and batches the
    test data, calls ``translate`` on it, then prints the per-word loss,
    accuracy and total word count reported by ``eval``.

    Args:
        args: namespace read for ``vocab_file``, ``model_file``, ``use_cuda``,
            ``test_file`` and ``batch_size``. ``en_total_words``,
            ``cn_total_words`` and ``num_test`` are written back onto it.

    Raises:
        SystemExit: with code -1 if the vocab or model file is missing.
    """
    import sys

    en_dict, cn_dict, en_total_words, cn_total_words = _load_vocab(args.vocab_file)
    args.en_total_words = en_total_words
    args.cn_total_words = cn_total_words
    # Reverse (value -> key) lookups, handed to ``translate`` below.
    inv_en_dict = {v: k for k, v in en_dict.items()}
    inv_cn_dict = {v: k for k, v in cn_dict.items()}

    model = _load_model(args.model_file, args.use_cuda)
    crit = utils.LanguageModelCriterion()

    test_en, test_cn = utils.load_data(args.test_file)
    args.num_test = len(test_en)
    test_en, test_cn = utils.encode(test_en, test_cn, en_dict, cn_dict)
    test_data = utils.gen_examples(test_en, test_cn, args.batch_size)

    translate(model, test_data, en_dict, inv_en_dict, cn_dict, inv_cn_dict)
    # ``eval`` here is presumably a function defined/imported in this module
    # that shadows the builtin (the builtin does not accept four arguments).
    # It appears to return totals, which are normalised per word below.
    correct_count, loss, num_words = eval(model, test_data, args, crit)
    if not num_words:
        # Guard against ZeroDivisionError on an empty test set.
        print("no test words to evaluate!", file=sys.stderr)
        return
    loss = loss / num_words
    acc = correct_count / num_words
    print("test loss %s" % loss)
    print("test accuracy %f" % acc)
    print("test total number of words %d" % num_words)
评论列表
文章目录