generate.py source code

python

Project: chainer-glu    Author: musyoku

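import numpy as np
import chainer
import chainer.functions as F

# load_model, load_vocab, ID_BOS and ID_EOS are defined elsewhere in the
# chainer-glu project and are not shown on this page.


def main(args):
    model = load_model(args.model_dir)
    assert model is not None

    vocab, vocab_inv = load_vocab(args.model_dir)
    assert vocab is not None
    assert vocab_inv is not None

    vocab_size = model.vocab_size
    word_ids = np.arange(0, vocab_size, dtype=np.int32)

    # Inference mode: disables training-only behaviour such as dropout.
    with chainer.using_config("train", False):
        for _ in range(args.num_generate):
            # Every sentence starts from the beginning-of-sentence token.
            token = ID_BOS
            x = np.asarray([[token]], dtype=np.int32)
            model.reset_state()
            # Sample one token at a time until EOS or the length limit.
            while token != ID_EOS and x.shape[1] < args.max_sentence_length:
                u = model.forward_one_step(x)
                # Next-token distribution; cast to float64 and renormalise so
                # np.random.choice does not reject it over rounding error.
                p = F.softmax(u).data[-1].astype(np.float64)
                p /= p.sum()
                token = int(np.random.choice(word_ids, p=p))
                x = np.append(x, np.asarray([[token]], dtype=np.int32), axis=1)

            # Map the sampled IDs (including BOS/EOS) back to words.
            sentence = [vocab_inv[int(t)] for t in x[0]]
            print(" ".join(sentence))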
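The core of main() is ancestral sampling: feed the tokens generated so far to the model, turn its scores into a probability distribution with softmax, draw the next token from that distribution, and stop at ID_EOS or once max_sentence_length tokens exist. The following is a minimal NumPy-only sketch of that loop so it can run without a trained chainer-glu model. It rests on assumptions that are not from the original: a toy bigram logits table stands in for model.forward_one_step, ID_BOS = 0 and ID_EOS = 1 are hypothetical values, and the five-word vocab_inv is made up for illustration.

```python
import numpy as np

# Hypothetical special-token IDs; the real values live in the chainer-glu project.
ID_BOS = 0
ID_EOS = 1


def softmax(u):
    """Numerically stable softmax over the last axis."""
    u = u - u.max(axis=-1, keepdims=True)
    e = np.exp(u)
    return e / e.sum(axis=-1, keepdims=True)


def sample_sentence(next_logits, vocab_size, max_sentence_length, rng):
    """Autoregressively sample token IDs, starting from ID_BOS.

    next_logits(ids) is a hypothetical stand-in for model.forward_one_step:
    given the IDs generated so far, it returns unnormalised next-token scores.
    """
    ids = [ID_BOS]
    token = ID_BOS
    while token != ID_EOS and len(ids) < max_sentence_length:
        p = softmax(next_logits(ids)).astype(np.float64)
        p /= p.sum()  # guard against rounding error before rng.choice
        token = int(rng.choice(vocab_size, p=p))
        ids.append(token)
    return ids


# Toy "model": a fixed bigram table of logits indexed by the previous token.
vocab_size = 5
rng = np.random.default_rng(0)
table = rng.normal(size=(vocab_size, vocab_size))
vocab_inv = {0: "<bos>", 1: "<eos>", 2: "the", 3: "cat", 4: "sat"}

ids = sample_sentence(lambda prefix: table[prefix[-1]], vocab_size, 10, rng)
print(" ".join(vocab_inv[i] for i in ids))
```

Passing next_logits as a callable keeps the loop independent of any particular network; plugging in a real model only requires a function that maps the ID prefix to next-token scores.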