chatbot.py source code

python

Project: chatbot-rnn    Author: zenixls2
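from six.moves import cPickle
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.ConfigProto, tf.train.Saver)

from model import Model  # assumed import path: the project's model.py

# get_paths() and chatbot() are assumed to be defined elsewhere in chatbot.py.

def sample_main(args):
    model_path, config_path, vocab_path = get_paths(args.save_dir)
    # The command-line arguments only point us to a saved model. Load the
    # arguments that model was originally trained with (saved_args) and
    # use them to rebuild the model.
    with open(config_path, 'rb') as f:  # pickles are bytes: open in binary mode
        saved_args = cPickle.load(f)
    # Separately load the character list and vocabulary from the save directory.
    with open(vocab_path, 'rb') as f:
        chars, vocab = cPickle.load(f)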
    # Create the model from the saved arguments, in inference mode.
    print("Creating model...")
    net = Model(saved_args, True)
    # Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(net.save_variables_list())
        # Restore the saved variables, replacing the initialized values.
        print("Restoring weights...")
        saver.restore(sess, model_path)
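        # Run the chatbot with the sampling options given on the command line.
        chatbot(net, sess, chars, vocab, args.n, args.beam_width, args.relevance, args.temperature)
        # Disabled alternative: sample text starting from args.prime instead.
        # beam_sample(net, sess, chars, vocab, args.n, args.prime,
        #             args.beam_width, args.relevance, args.temperature)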
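sample_main() expects two pickles in the save directory. One holds the training arguments and the other holds the (chars, vocab) pair. The sketch below shows how such artifacts can be written and read back with the standard pickle module. It assumes the file names config.pkl and chars_vocab.pkl and a char-rnn style vocabulary, in which chars is ordered by descending frequency and vocab maps each character to its index. build_vocab, save_artifacts, load_artifacts and demo_roundtrip are hypothetical helpers, not functions from the repository.

```python
import argparse
import os
import pickle
import tempfile


def build_vocab(text):
    # Assumed char-rnn convention: characters sorted by descending frequency
    # (ties broken alphabetically here), then mapped to consecutive indices.
    counts = {}
    for ch in text:
        counts[ch] = counts.get(ch, 0) + 1
    chars = tuple(sorted(counts, key=lambda c: (-counts[c], c)))
    vocab = {ch: i for i, ch in enumerate(chars)}
    return chars, vocab


def save_artifacts(save_dir, args, chars, vocab):
    # Pickle output is bytes, so open the files in binary ("wb") mode.
    with open(os.path.join(save_dir, "config.pkl"), "wb") as f:
        pickle.dump(args, f)
    with open(os.path.join(save_dir, "chars_vocab.pkl"), "wb") as f:
        pickle.dump((chars, vocab), f)


def load_artifacts(save_dir):
    # Mirror image of save_artifacts: binary ("rb") mode on the way back in.
    with open(os.path.join(save_dir, "config.pkl"), "rb") as f:
        saved_args = pickle.load(f)
    with open(os.path.join(save_dir, "chars_vocab.pkl"), "rb") as f:
        chars, vocab = pickle.load(f)
    return saved_args, chars, vocab


def demo_roundtrip(text="hello world"):
    # Hypothetical hyperparameters standing in for the real training arguments.
    args = argparse.Namespace(rnn_size=128, num_layers=2)
    chars, vocab = build_vocab(text)
    with tempfile.TemporaryDirectory() as save_dir:
        save_artifacts(save_dir, args, chars, vocab)
        saved_args, chars, vocab = load_artifacts(save_dir)
    # Encode with vocab (char -> index), decode with chars (index -> char).
    decoded = "".join(chars[vocab[c]] for c in text)
    return saved_args, chars, vocab, decoded
```

Under Python 3, opening the pickles in text mode (the original open(path) calls) makes pickle.load fail. This is why the loading code above uses 'rb'.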
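chatbot() receives args.temperature, but the sampling code itself is not shown on this page. The sketch below shows the standard temperature-scaled softmax that such a parameter usually controls. It is offered as an illustration of the technique, not as the repository's implementation. softmax_with_temperature and sample_index are hypothetical names.

```python
import math
import random


def softmax_with_temperature(logits, temperature=1.0):
    # Divide the logits by the temperature before the softmax:
    # T < 1 sharpens the distribution, T > 1 flattens it.
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max so exp() cannot overflow
    weights = [math.exp(x - peak) for x in scaled]
    total = sum(weights)
    return [w / total for w in weights]


def sample_index(probs, rng=random):
    # Draw one index according to the given probabilities.
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]
```

With logits [1.0, 2.0, 3.0], a temperature of 0.1 puts more than 99% of the mass on index 2, while a temperature of 10.0 makes the three choices nearly equally likely.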
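The args.beam_width option suggests that replies are decoded with beam search. The following is a generic, minimal beam search over a toy next-token function. It assumes the model exposes log-probabilities for the next token. It does not reproduce the repository's code, which may also apply args.relevance when scoring. beam_search, TOY_MODEL and toy_next_log_probs are hypothetical.

```python
import heapq
import math


def beam_search(next_log_probs, start, beam_width, max_steps):
    # next_log_probs(seq) returns {next_token: log_probability}; an empty
    # dict means the sequence cannot be extended any further.
    beams = [(0.0, list(start))]
    for _ in range(max_steps):
        candidates = []
        for score, seq in beams:
            expansions = next_log_probs(seq)
            if not expansions:
                candidates.append((score, seq))  # finished: carry over unchanged
                continue
            for token, log_p in expansions.items():
                # Log-probabilities add, so this is log P(seq + [token]).
                candidates.append((score + log_p, seq + [token]))
        # Keep only the beam_width highest-scoring partial sequences.
        beams = heapq.nlargest(beam_width, candidates, key=lambda c: c[0])
    return beams


# Toy model (hypothetical): greedy decoding picks "a" first, but "b" leads
# to a more probable two-token sequence.
TOY_MODEL = {
    (): {"a": math.log(0.6), "b": math.log(0.4)},
    ("a",): {"x": math.log(0.5), "y": math.log(0.5)},
    ("b",): {"z": math.log(0.9), "w": math.log(0.1)},
}


def toy_next_log_probs(seq):
    return TOY_MODEL.get(tuple(seq), {})
```

With beam_width=1 this reduces to greedy decoding and returns a sequence starting with "a" (probability 0.6 × 0.5 = 0.3). With beam_width=2, the best beam is ["b", "z"] (probability 0.4 × 0.9 = 0.36). A wider beam costs more model evaluations per step but can recover sequences that greedy decoding discards early.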