def __init__(self, save_dir=SAVE_DIR, prime_text=PRIME_TEXT, num_sample_symbols=NUM_SAMPLE_SYMBOLS):
    """Load the vocabulary, build the model and restore the latest checkpoint.

    Args:
        save_dir: Directory containing 'chars_vocab.pkl' and the TensorFlow
            checkpoint state.
        prime_text: Stored on the instance as ``self.prime_text``; not used
            further in the constructor.
        num_sample_symbols: Stored on the instance as
            ``self.num_sample_symbols``; not used further in the constructor.
    """
    self.save_dir = save_dir
    self.prime_text = prime_text
    self.num_sample_symbols = num_sample_symbols

    # Read the vocabulary from the same directory the checkpoint is restored
    # from, so a caller-supplied save_dir is honoured consistently.
    # NOTE: unpickling can execute arbitrary code; save_dir must be trusted.
    vocab_path = os.path.join(self.save_dir, 'chars_vocab.pkl')
    with open(vocab_path, 'rb') as vocab_file:
        self.chars, self.vocab = cPickle.load(vocab_file)

    self.model = Model(len(self.chars), is_sampled=True)

    # Polite GPU memory allocation: don't grab everything you can.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.allocator_type = 'BFC'
    self.sess = tf.Session(config=config)

    # Initialise every variable first; a checkpoint, when present, then
    # overwrites them. Without a checkpoint the initial values are kept.
    tf.initialize_all_variables().run(session=self.sess)
    self.checkpoint = tf.train.get_checkpoint_state(self.save_dir)
    if self.checkpoint and self.checkpoint.model_checkpoint_path:
        saver = tf.train.Saver(tf.all_variables())
        saver.restore(self.sess, self.checkpoint.model_checkpoint_path)
# Comment list
# Table of contents