def sample_main(args):
    """Load a previously trained model from ``args.save_dir`` and launch an
    interactive chatbot session against it.

    Args:
        args: Parsed CLI namespace. Reads ``save_dir`` (checkpoint directory),
            ``n``, ``beam_width``, ``relevance``, and ``temperature``, which
            are forwarded to ``chatbot``.
    """
    model_path, config_path, vocab_path = get_paths(args.save_dir)
    # Arguments passed to sample.py direct us to a saved model.
    # Load the separate arguments by which that model was previously trained.
    # That's saved_args. Use those to load the model.
    # NOTE: pickle files must be opened in binary mode — text mode is a hard
    # error on Python 3 and corrupts data on platforms with newline translation.
    with open(config_path, 'rb') as f:
        saved_args = cPickle.load(f)
    # Separately load chars and vocab from the save directory.
    with open(vocab_path, 'rb') as f:
        chars, vocab = cPickle.load(f)
    # Create the model from the saved arguments, in inference mode.
    print("Creating model...")
    net = Model(saved_args, True)
    # Grow GPU memory on demand rather than pre-allocating all of it.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(net.save_variables_list())
        # Restore the saved variables, replacing the initialized values.
        print("Restoring weights...")
        saver.restore(sess, model_path)
        chatbot(net, sess, chars, vocab, args.n, args.beam_width,
                args.relevance, args.temperature)
# (Removed stray web-page navigation text accidentally pasted here:
# "评论列表" / "文章目录" — "comment list" / "article table of contents".)