train_gan_new.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:CharacterGAN 作者: liamb315 项目源码 文件源码
def generate_samples(generator, args, sess, num_samples=500):
    '''Generate samples from the current version of the GAN.

    Loads the saved training config and character vocabulary from
    ``args.save_dir_GAN``. It then restores the latest GAN checkpoint found
    there into the variables whose names start with ``sampler/``. Finally it
    asks ``generator`` for ``num_samples // args.batch_size`` batches of
    samples.

    Args:
        generator: Object exposing
            ``generate_samples(sess, saved_args, chars, vocab, n)``. Its
            return value is collected once per batch.
        args: Namespace providing ``save_dir_GAN``, ``vocab_file``,
            ``batch_size`` and ``n`` (``n`` is forwarded unchanged to
            ``generator.generate_samples``).
        sess: TensorFlow session used for the checkpoint restore and for
            sampling.
        num_samples: Requested number of samples. Only whole batches are
            produced, so any remainder below ``args.batch_size`` is dropped.
            A value smaller than ``args.batch_size`` yields an empty list.

    Returns:
        list: One entry per batch, each being whatever
        ``generator.generate_samples`` returned.
    '''
    # Pickles must be read in binary mode. Text mode can corrupt binary
    # pickle protocols on platforms that translate newlines (e.g. Windows).
    # NOTE(review): cPickle.load can execute arbitrary code from the file, so
    # only point save_dir_GAN at trusted directories.
    with open(os.path.join(args.save_dir_GAN, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir_GAN, args.vocab_file), 'rb') as f:
        chars, vocab = cPickle.load(f)

    logging.debug('Loading GAN parameters to Generator...')
    # Key:   variable op name as stored in the GAN checkpoint, which is the
    #        local name with the 'sampler/' prefix stripped.
    # Value: the local sampler Variable that should receive the value.
    gen_dict = {v.op.name.replace('sampler/', ''): v
                for v in tf.all_variables()
                if v.name.startswith('sampler/')}
    # NOTE(review): each tf.train.Saver adds save/restore ops to the graph.
    # If this function is called repeatedly (e.g. inside a training loop),
    # the graph grows on every call. Consider building the saver once.
    gen_saver = tf.train.Saver(gen_dict)
    ckpt = tf.train.get_checkpoint_state(args.save_dir_GAN)
    if ckpt and ckpt.model_checkpoint_path:
        gen_saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        # Do not fail silently. Without a checkpoint, the sampler variables
        # keep whatever state they already have (or sampling fails if they
        # are uninitialized).
        logging.warning('No GAN checkpoint found in %s; sampling without '
                        'restoring GAN parameters.', args.save_dir_GAN)

    # Floor division makes the whole-batch truncation explicit. It matches
    # Python 2 integer '/' and stays an int on Python 3.
    num_batches = num_samples // args.batch_size
    return [generator.generate_samples(sess, saved_args, chars, vocab, args.n)
            for _ in range(num_batches)]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号