base.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:ternarynet 作者: czhu95 项目源码 文件源码
def main_loop(self):
        """Finish graph setup, then run the epoch/step training loop.

        Before training starts, this sets up summaries and callbacks,
        initializes variables, applies ``self.config.session_init``,
        finalizes the default graph and starts concurrency.

        It then runs ``self.config.step_per_epoch`` calls to
        ``self.run_step()`` for every epoch from
        ``self.config.starting_epoch`` through ``self.config.max_epoch``
        (inclusive), calling ``self.trigger_epoch()`` after each epoch.

        The method returns early as soon as ``self.coord.should_stop()``
        is true.  Any exception raised during training propagates to the
        caller.  In every case the teardown sequence (``after_train``
        callback, coordinator stop request, summary writer close, session
        close) is attempted in full, even if one of its steps raises.
        """
        # some final operations that might modify the graph
        self._init_summary()
        get_global_step_var()   # ensure there is such var, before finalizing the graph
        logger.info("Setup callbacks ...")
        callbacks = self.config.callbacks
        callbacks.setup_graph(self)  # TODO use weakref instead?
        logger.info("Initializing graph variables ...")
        self.sess.run(tf.initialize_all_variables())
        self.config.session_init.init(self.sess)
        tf.get_default_graph().finalize()
        self._start_concurrency()

        with self.sess.as_default():
            try:
                self.global_step = get_global_step()
                logger.info("Start training with global_step={}".format(self.global_step))

                callbacks.before_train()
                for epoch in range(self.config.starting_epoch, self.config.max_epoch+1):
                    # The label shows the global_step this epoch will reach
                    # if it completes all of its steps.
                    with timed_operation(
                        'Epoch {}, global_step={}'.format(
                            epoch, self.global_step + self.config.step_per_epoch)):
                        for _ in tqdm.trange(
                                self.config.step_per_epoch,
                                **get_tqdm_kwargs(leave=True)):
                            if self.coord.should_stop():
                                # Early exit; the ``finally`` below still
                                # performs the teardown.
                                return
                            self.run_step()
                            # NOTE: a per-step callbacks.trigger_step() call
                            # was disabled here ("not useful?").
                            self.global_step += 1
                        self.trigger_epoch()
            finally:
                # TODO: Do I need to run queue.close?
                # Chain the teardown steps with nested try/finally so that a
                # failure in one of them cannot prevent the remaining ones
                # from running (otherwise the writer/session would be left
                # open).  Any exception still propagates to the caller.
                try:
                    callbacks.after_train()
                finally:
                    try:
                        self.coord.request_stop()
                    finally:
                        try:
                            self.summary_writer.close()
                        finally:
                            self.sess.close()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号