sample.py 文件源码

python
阅读 37 收藏 0 点赞 0 评论 0

项目:sequelspeare 作者: raidancampbell 项目源码 文件源码
def __init__(self, save_dir=SAVE_DIR, prime_text=PRIME_TEXT, num_sample_symbols=NUM_SAMPLE_SYMBOLS):
        """Load the vocabulary, build the sampling model and restore its weights.

        Args:
            save_dir: Directory holding ``chars_vocab.pkl`` and the TensorFlow
                checkpoint files.
            prime_text: Stored on the instance as ``self.prime_text``.
            num_sample_symbols: Stored on the instance as
                ``self.num_sample_symbols``.

        Raises:
            IOError/OSError: If ``chars_vocab.pkl`` cannot be opened in
                ``save_dir``.
        """
        self.save_dir = save_dir
        self.prime_text = prime_text
        self.num_sample_symbols = num_sample_symbols

        # Read the vocabulary from the directory the caller asked for, so the
        # vocabulary and the checkpoint restored below always come from the
        # same place (previously the class-level SAVE_DIR was always used here).
        vocab_path = os.path.join(self.save_dir, 'chars_vocab.pkl')
        # NOTE(security): unpickling executes arbitrary code embedded in the
        # file; only point save_dir at directories you trust.
        with open(vocab_path, 'rb') as vocab_file:
            self.chars, self.vocab = cPickle.load(vocab_file)

        # The file is no longer needed once the pickle is loaded, so the rest
        # of the setup happens outside the `with` block.
        self.model = Model(len(self.chars), is_sampled=True)

        # polite GPU memory allocation: don't grab everything you can.
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.gpu_options.allocator_type = 'BFC'
        self.sess = tf.Session(config=config)

        # Initialise every variable first; a checkpoint, if one exists,
        # then overwrites those initial values.
        tf.initialize_all_variables().run(session=self.sess)
        self.checkpoint = tf.train.get_checkpoint_state(self.save_dir)
        if self.checkpoint and self.checkpoint.model_checkpoint_path:
            saver = tf.train.Saver(tf.all_variables())
            saver.restore(self.sess, self.checkpoint.model_checkpoint_path)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号