chainer_model.py source code

python

Project: char-rnn-text-generation  Author: yxtay
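import chainer
import chainer.functions as F
import numpy as np
from chainer import Variable

# logger, encode_text, sample_from_probs and ID2CHAR are defined elsewhere
# in the project and are not shown in this excerpt.


def generate_text(model, seed, length=512, top_n=10):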
    """
    generates text of specified length from trained model
    with given seed character sequence.
    """
    logger.info("generating %s characters from top %s choices.", length, top_n)
    logger.info('generating with seed: "%s".', seed)
    generated = seed
    encoded = encode_text(seed).astype(np.int32)
    model.predictor.reset_state()

    with chainer.using_config("train", False), chainer.no_backprop_mode():
        # warm up: feed every seed character except the last
        # so the recurrent state reflects the seed
        for idx in encoded[:-1]:
            x = Variable(np.array([idx], dtype=np.int32))
            # input shape: [1]
            model.predictor(x)

        next_index = encoded[-1]
        for i in range(length):
            x = Variable(np.array([next_index], dtype=np.int32))
            # input shape: [1]
            probs = F.softmax(model.predictor(x))
            # output shape: [1, vocab_size]
            next_index = sample_from_probs(probs.data.squeeze(), top_n)
            # append to sequence
            generated += ID2CHAR[next_index]

    logger.info("generated text: \n%s\n", generated)
    return generated
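
The function above relies on `sample_from_probs`, which is not shown in this excerpt. As a rough illustration of what top-n sampling could look like, here is a minimal sketch. The name `sample_top_n`, its `rng` parameter, and the implementation details are assumptions for illustration only, not the project's actual helper.

```python
import numpy as np


def sample_top_n(probs, top_n=10, rng=None):
    """Hypothetical sketch: sample an index from the top_n most likely entries."""
    rng = np.random.default_rng() if rng is None else rng
    probs = np.asarray(probs, dtype=np.float64)
    # clamp top_n to a valid range (assumption: at least one candidate is kept)
    top_n = max(1, min(top_n, probs.size))
    # indices of the top_n largest probabilities (in no particular order)
    top_idx = np.argpartition(probs, -top_n)[-top_n:]
    # renormalise the kept probabilities so they sum to 1
    top_probs = probs[top_idx] / probs[top_idx].sum()
    return int(rng.choice(top_idx, p=top_probs))


probs = np.array([0.1, 0.6, 0.05, 0.25])
index = sample_top_n(probs, top_n=2)  # only index 1 or 3 can be drawn
```

Restricting the draw to the most probable candidates trades some diversity for fewer implausible characters; with `top_n=1` the sketch reduces to greedy (argmax) decoding.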