trainer.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:pointer-network-tensorflow 作者: devsisters 项目源码 文件源码
def train(self):
    tf.logging.info("Training starts...")
    self.data_loader.run_input_queue(self.sess)

    summary_writer = None
    for k in trange(self.max_step, desc="train"):
      fetch = {
          'optim': self.model.optim,
      }
      result = self.model.train(self.sess, fetch, summary_writer)

      if result['step'] % self.log_step == 0:
        self._test(self.summary_writer)

      summary_writer = self._get_summary_writer(result)

    self.data_loader.stop_input_queue()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号