calc_plex.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:Tree-LSTM-LM 作者: vgene 项目源码 文件源码
def main():
    import sys
    reload(sys)
    sys.setdefaultencoding("utf-8")
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--model', type=str)
    argparser.add_argument('--test_file', type=str)
    argparser.add_argument('--cuda', action='store_true')
    args = argparser.parse_args()

    model = torch.load(args.model)
    print(model.vocab_size)
    batch_size = 1000
    tester = Tester(args.test_file, batch_size, model.mapping)
    perplexity = tester.calc_perplexity(model, cuda=args.cuda)
    print("Test File: {}, Perplexity:{}".format(args.test_file, perplexity))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号