def main():
    """Evaluate a saved model's perplexity on a test file.

    Command-line arguments:
        --model       path to a model object serialized with ``torch.save``
                      (required).
        --test_file   path to the evaluation data handed to ``Tester``
                      (required).
        --cuda        flag forwarded to ``Tester.calc_perplexity``; when it
                      is absent, checkpoint storages are mapped onto the CPU.
        --batch_size  batch size handed to ``Tester`` (default: 1000).

    Prints the model's ``vocab_size`` and then the perplexity on the test file.
    """
    import sys

    # Legacy Python 2 default-encoding hack. On Python 3 ``reload`` is not a
    # builtin and ``sys.setdefaultencoding`` does not exist, so running these
    # lines unconditionally raised NameError there; keep them Python-2-only.
    if sys.version_info[0] < 3:
        reload(sys)  # noqa: F821 -- builtin on Python 2 only
        sys.setdefaultencoding("utf-8")

    argparser = argparse.ArgumentParser()
    # Both paths are mandatory; failing here gives a clear usage error instead
    # of a confusing crash inside torch.load(None) / Tester(None, ...).
    argparser.add_argument('--model', type=str, required=True)
    argparser.add_argument('--test_file', type=str, required=True)
    argparser.add_argument('--cuda', action='store_true')
    argparser.add_argument('--batch_size', type=int, default=1000)
    args = argparser.parse_args()

    # Without --cuda, remap every storage onto the CPU so a checkpoint saved
    # from a GPU run can still be loaded on a CPU-only machine. With --cuda,
    # keep torch.load's default placement (map_location=None), as before.
    # NOTE: torch.load unpickles arbitrary objects -- only load trusted files.
    map_location = None if args.cuda else (lambda storage, loc: storage)
    model = torch.load(args.model, map_location=map_location)
    print(model.vocab_size)

    tester = Tester(args.test_file, args.batch_size, model.mapping)
    perplexity = tester.calc_perplexity(model, cuda=args.cuda)
    print("Test File: {}, Perplexity:{}".format(args.test_file, perplexity))
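# 评论列表 ("comment list") -- non-code text, kept as a comment so the file parses
# 文章目录 ("table of contents") -- non-code text, kept as a comment so the file parses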