graph_actions_test.py 文件源码

python
阅读 61 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def test_train_loss(self):
  """Check the loss returned by `_monitored_train` after a fixed step count.

  Builds a toy graph whose "loss" is a local variable starting at 10.0. It
  uses a train op that bumps the global step by 1 and lowers the loss
  variable by 1.0 on each run. Before training it asserts (via the test
  class's `_assert_summaries` / `_assert_ckpt` helpers) the initial summary and
  checkpoint state of the output dir. It then trains for 6 steps and expects
  a returned loss of 4.0. Finally it re-checks summaries (now expecting the
  graph and its meta graph) and the checkpoint state.
  """
  with tf.Graph().as_default() as graph, self.test_session(graph):
    tf.contrib.framework.create_global_step()
    # Toy "loss": a local variable the train op decrements every step.
    loss_variable = tf.contrib.framework.local_variable(10.0)
    # Ops are created in the same order as a single nested tf.group call.
    step_increment = tf.assign_add(
        tf.contrib.framework.get_global_step(), 1)
    loss_decrement = tf.assign_add(loss_variable, -1.0)
    train_op = tf.group(step_increment, loss_decrement)

    output_dir = self._output_dir
    summary_writer = learn.graph_actions.get_summary_writer(output_dir)
    # Pre-training state; the False flag presumably means "no checkpoint
    # expected yet" — confirm against _assert_ckpt.
    self._assert_summaries(output_dir, summary_writer)
    self._assert_ckpt(output_dir, False)

    final_loss = learn.graph_actions._monitored_train(  # pylint: disable=protected-access
        graph,
        output_dir=output_dir,
        train_op=train_op,
        loss_op=loss_variable.value(),
        steps=6)

    # Built from the default graph after training so it matches what the
    # summary writer is expected to have recorded.
    expected_meta_graph = meta_graph.create_meta_graph_def()
    self.assertEqual(4.0, final_loss)
    self._assert_summaries(output_dir, summary_writer,
                           expected_graphs=[graph],
                           expected_meta_graphs=[expected_meta_graph])
    self._assert_ckpt(output_dir, True)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号