def generate_samples(generator, args, sess, num_samples=500):
    """Generate samples from the current version of the GAN.

    Loads the saved sampling configuration (``config.pkl``) and the
    vocabulary (``args.vocab_file``) from ``args.save_dir_GAN``, restores the
    latest GAN checkpoint found there (if any) into the variables under the
    ``sampler/`` scope, then calls ``generator.generate_samples`` once per
    batch.

    Args:
        generator: Object exposing
            ``generate_samples(sess, saved_args, chars, vocab, n)``.
        args: Namespace read for ``save_dir_GAN``, ``vocab_file``,
            ``batch_size`` and ``n``.
        sess: TensorFlow session the checkpoint is restored into and that is
            passed on to ``generator.generate_samples``.
        num_samples: Requested number of samples. Only whole batches are
            generated: ``num_samples // args.batch_size`` calls are made and
            any remainder is dropped (so fewer than one batch yields ``[]``).

    Returns:
        A list with one entry per ``generator.generate_samples`` call.
    """
    sampler_prefix = 'sampler/'
    samples = []

    # Pickles must be read in binary mode; text mode can corrupt binary
    # pickle protocols (on Windows) and fails under Python 3.
    # NOTE(review): unpickling can execute arbitrary code -- only load
    # these files from a trusted save directory.
    with open(os.path.join(args.save_dir_GAN, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir_GAN, args.vocab_file), 'rb') as f:
        chars, vocab = cPickle.load(f)

    logging.debug('Loading GAN parameters to Generator...')
    # Key: op name as stored in the GAN checkpoint (without the scope prefix).
    # Value: the local generator variable living under 'sampler/'.
    # Only the leading prefix is stripped; str.replace would also remove any
    # later 'sampler/' occurring inside the name and produce a wrong key.
    gen_dict = {
        v.op.name[len(sampler_prefix):]: v
        for v in tf.all_variables()
        if v.name.startswith(sampler_prefix)
    }
    # NOTE(review): every tf.train.Saver adds save/restore ops to the graph,
    # so calling this function repeatedly grows the graph; consider building
    # the saver once and reusing it.
    gen_saver = tf.train.Saver(gen_dict)
    ckpt = tf.train.get_checkpoint_state(args.save_dir_GAN)
    if ckpt and ckpt.model_checkpoint_path:
        gen_saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        # Without a checkpoint the sampler variables keep whatever values
        # they currently hold; make that visible instead of silent.
        logging.warning('No GAN checkpoint found in %s; sampling with the '
                        'current (unrestored) Generator parameters.',
                        args.save_dir_GAN)

    for _ in range(num_samples // args.batch_size):
        samples.append(generator.generate_samples(
            sess, saved_args, chars, vocab, args.n))
    return samples
# Comment list
# Table of contents