def train(self):
    """Train the configured agent(s) inside one TensorFlow monitored session.

    Sets ``self.training`` to True, then opens a
    ``tf.train.MonitoredTrainingSession`` whose checkpoint directory is
    ``self.config.run_dir``. Inside that session:
    - if there is exactly one agent, it is trained directly via ``train_agent``;
    - otherwise (zero or several agents), training is handed to
      ``train_threaded``.

    A completion message is logged once the session has been closed.
    """
    self.training = True
    util.log('Creating session and loading checkpoint')

    checkpoint_dir = self.config.run_dir
    # allow_growth makes TensorFlow claim GPU memory on demand rather than
    # reserving the whole device up front.
    gpu_opts = tf.GPUOptions(allow_growth=True)
    monitored = tf.train.MonitoredTrainingSession(
        checkpoint_dir=checkpoint_dir,
        # A value of 0 disables the periodic summary hook. Per the original
        # note, summaries are written together with train_op only.
        save_summaries_steps=0,
        config=tf.ConfigProto(gpu_options=gpu_opts),
    )

    with monitored:
        agents = self.agents
        if len(agents) != 1:
            self.train_threaded(monitored)
        else:
            self.train_agent(monitored, agents[0])
    util.log('Training complete')
# Scraped web-page residue (not code): "Comment list", "Table of contents"