eval_qrnn.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目: chainer-qrnn 作者: butsugiri 项目源码 文件源码
def main(args):
    """Rebuild a QRNN language model from a config, load its weights, and report test perplexity.

    Args:
        args: Parsed command-line arguments. Reads ``args.config_path`` (path to
            a JSON config file) and ``args.model_path`` (``.npz`` file loaded
            into the model with ``S.load_npz``).

    The JSON config must provide the keys:
        ``"dim"``      -- passed to ``QRNNLangModel`` as ``embed_dim``;
        ``"unit"``     -- passed to ``QRNNLangModel`` as ``out_size``;
        ``"data"``     -- passed to ``DataProcessor`` as ``data_path``;
        ``"bproplen"`` -- passed to the test iterator as ``bprop_len``.

    All progress messages and the final perplexity are written to stderr.
    """
    # Read the config first; the file can be closed before any heavy work starts.
    with open(args.config_path, 'r', encoding='utf-8') as fi:
        config = json.load(fi)
    embed_dim = config["dim"]
    hidden_dim = config["unit"]
    print("Embedding Dimension: {}\nHidden Dimension: {}\n".format(embed_dim, hidden_dim), file=sys.stderr)

    # Load data (full dataset, not a test run).
    dp = DataProcessor(data_path=config["data"], test_run=False)
    dp.prepare_dataset()

    # Build the model; presumably the vocab must match the one the weights were
    # trained with, otherwise load_npz will hit a shape mismatch -- confirm.
    vocab = dp.vocab
    model = RecNetClassifier(QRNNLangModel(n_vocab=len(vocab), embed_dim=embed_dim, out_size=hidden_dim))

    # Load trained parameters into the freshly built model.
    print("loading parameters to model...", end='', file=sys.stderr, flush=True)
    S.load_npz(filename=args.model_path, obj=model)
    print("done.", file=sys.stderr, flush=True)

    # Single pass (repeat=False) over the test split with batch size 1.
    bprop_len = config["bproplen"]
    test_data = dp.test_data
    test_iter = ParallelSequentialIterator(test_data, 1, repeat=False, bprop_len=bprop_len)

    # Evaluate from a clean recurrent state.
    print('testing...', end='', file=sys.stderr, flush=True)
    model.predictor.reset_state()
    # NOTE(review): Chainer v1-style train flag; on Chainer v2+ the Evaluator
    # switches the global 'train' config itself -- confirm the Chainer version.
    model.predictor.train = False
    evaluator = extensions.Evaluator(test_iter, model, converter=convert)
    result = evaluator()
    print('done.\n', file=sys.stderr, flush=True)
    # Perplexity = exp(loss); assumes 'main/loss' is the mean per-token
    # cross-entropy reported by RecNetClassifier -- confirm.
    # Terminate the line so the final output does not run into the shell prompt.
    print('Perplexity: {}'.format(np.exp(float(result['main/loss']))), file=sys.stderr, flush=True)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号