def sample_generator(args, num_samples=10):
    """Restore a trained Generator and return one batch of generated samples.

    Loads the pickled model configuration and the (chars, vocab) pair from
    ``args.save_dir_GAN``, builds a ``Generator`` with ``is_training=False``
    and ``batch=True``, restores the latest checkpoint found under
    ``args.save_dir`` (if any), and returns ``generator.generate_batch(...)``.

    Args:
        args: Namespace-like object; this function reads ``args.save_dir_GAN``
            and ``args.save_dir``.
        num_samples: Currently unused; kept so existing callers keep working.

    Returns:
        Whatever ``Generator.generate_batch`` returns for the restored model.
    """
    # Pickle data is binary: open in 'rb' so it is not mangled by newline
    # translation (Windows) and so Python 3's pickle accepts the file object.
    with open(os.path.join(args.save_dir_GAN, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir_GAN, 'real_beer_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)

    generator = Generator(saved_args, is_training=False, batch=True)

    with tf.Session() as sess:
        # Initialize everything first; the restore below overwrites the
        # variables only if a checkpoint exists. Without one, sampling runs
        # on freshly initialized (untrained) weights.
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        # NOTE(review): config/vocab are read from save_dir_GAN, but the
        # checkpoint is looked up in save_dir -- confirm this is intentional.
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        return generator.generate_batch(sess, saved_args, chars, vocab)
评论列表
文章目录