def sample(args):
    """Restore the best checkpoint from ``args.save_dir`` and write a text sample.

    Loads the pickled training config (``config.pkl``) and character
    vocabulary (``chars_vocab.pkl``) from ``args.save_dir``, and builds
    ``Model(saved_args, True)``. It then restores the checkpoint with the
    lowest validation loss recorded in ``val_loss.json`` and generates text
    with ``model.sample``. The text is printed and written to a timestamped
    ``.txt`` file under ``/data/output/``, whose path is also printed.

    Args:
        args: parsed command-line namespace. Reads ``save_dir``, ``n``,
            ``prime``, ``sample_rule`` and ``temperature``.

    Returns:
        None. If ``val_loss.json`` is missing or empty, a message is printed
        and nothing is sampled.
    """
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)
    model = Model(saved_args, True)
    val_loss_file = os.path.join(args.save_dir, 'val_loss.json')
    with tf.Session() as sess:
        saver = tf.train.Saver(tf.all_variables())
        model_checkpoint_path = _best_checkpoint_path(val_loss_file)
        if model_checkpoint_path is None:
            # Without a recorded checkpoint there is nothing to restore;
            # previously this fell through to an unbound variable.
            print("No validation-loss checkpoints found in " + val_loss_file)
            return
        saver.restore(sess, model_checkpoint_path)
        result = model.sample(sess, chars, vocab, args.n, args.prime,
                              args.sample_rule, args.temperature)
        print(result)
        # One file per run, named by the current Unix timestamp in seconds.
        output = os.path.join("/data/output", str(int(time.time())) + ".txt")
        with open(output, "w") as text_file:
            text_file.write(result)
        print(output)


def _best_checkpoint_path(val_loss_file):
    """Return the checkpoint path with the lowest validation loss, or None.

    ``val_loss_file`` holds a JSON object whose keys are loss values
    serialized as strings. Each value is a dict with a ``'checkpoint_path'``
    entry. Returns None when the file does not exist or the object is empty.
    """
    if not os.path.exists(val_loss_file):
        return None
    with open(val_loss_file, "r") as f:
        loss_json = json.load(f)
    if not loss_json:
        return None
    # Keys are stringified floats: compare numerically, not lexically.
    # min() also avoids calling .sort() on dict_keys, which fails on Python 3.
    best_loss = min(loss_json, key=float)
    return loss_json[best_loss]['checkpoint_path']
# 评论列表 (scraped web-page residue: "Comment list")
# 文章目录 (scraped web-page residue: "Table of contents")