def main_loop(self) -> None:
    """Run the whole training procedure, then release training resources.

    Steps, in order:

    1. Finish building the graph: create the summary ops
       (``self._init_summary``), make sure the global-step variable exists,
       and let the callbacks attach their ops (``callbacks.setup_graph``).
    2. Initialize all variables, then apply ``config.session_init`` to the
       session (presumably to restore/override values, e.g. from a
       checkpoint -- confirm against the SessionInit implementations).
    3. Finalize the default graph so nothing more can be added to it, and
       call ``self._start_concurrency()`` (assumed to launch background
       threads coordinated by ``self.coord`` -- TODO confirm).
    4. With ``self.sess`` as the default session, train epochs
       ``config.starting_epoch`` .. ``config.max_epoch`` (both inclusive).
       Each epoch calls ``self.run_step()`` ``config.step_per_epoch`` times,
       incrementing ``self.global_step`` after every step, and then calls
       ``self.trigger_epoch()``. Training returns early as soon as
       ``self.coord.should_stop()`` is true.

    However training ends (normal completion, early stop, or an exception),
    the teardown runs: ``callbacks.after_train()``,
    ``self.coord.request_stop()``, closing the summary writer and closing
    the session. A failing ``after_train`` no longer skips the remaining
    teardown, and the session is always closed. Exceptions are not
    swallowed; they still propagate to the caller.
    """
    # Final operations that might still modify the graph; they must happen
    # before the graph is finalized below.
    self._init_summary()
    get_global_step_var()  # ensure there is such var, before finalizing the graph
    logger.info("Setup callbacks ...")
    callbacks = self.config.callbacks
    callbacks.setup_graph(self)  # TODO use weakref instead?
    logger.info("Initializing graph variables ...")
    self.sess.run(tf.initialize_all_variables())
    # Runs after the blanket initialization above, so it can overwrite those
    # initial values.
    self.config.session_init.init(self.sess)
    # From here on the graph is read-only.
    tf.get_default_graph().finalize()
    self._start_concurrency()

    with self.sess.as_default():
        try:
            self.global_step = get_global_step()
            logger.info("Start training with global_step={}".format(self.global_step))
            callbacks.before_train()
            for epoch in range(self.config.starting_epoch, self.config.max_epoch + 1):
                # The label is computed before the epoch starts, so it shows
                # the global step the epoch is expected to finish at.
                epoch_label = 'Epoch {}, global_step={}'.format(
                    epoch, self.global_step + self.config.step_per_epoch)
                with timed_operation(epoch_label):
                    for _ in tqdm.trange(
                            self.config.step_per_epoch,
                            **get_tqdm_kwargs(leave=True)):
                        if self.coord.should_stop():
                            # The finally clause below still performs teardown.
                            return
                        self.run_step()
                        # NOTE(review): per-step callbacks
                        # (callbacks.trigger_step) were deliberately disabled
                        # by the original author ("not useful?").
                        self.global_step += 1
                    self.trigger_epoch()
        finally:
            # TODO(review): original author asked whether the input queue
            # also needs queue.close() here.
            # Teardown steps are nested so that an error in user callback
            # code (after_train) cannot skip stopping the coordinator or
            # leak the summary writer / session; the session is closed last
            # and unconditionally.
            try:
                callbacks.after_train()
            finally:
                try:
                    self.coord.request_stop()
                    self.summary_writer.close()
                finally:
                    self.sess.close()
评论列表
文章目录