graph_actions_test.py 文件源码

python
阅读 48 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def test_train_max_steps_is_not_incremental(self):
  """Check that `max_steps` is an absolute bound rather than an increment.

  Training runs twice against the same `self._output_dir`, first with
  max_steps=10 and then with max_steps=15. After each run, the global step
  value loaded back from `self._output_dir` must equal that run's
  max_steps. So the second run ends at 15 total steps, not 10 + 15.
  """
  for target_steps in (10, 15):
    # Each run gets a fresh graph and session; only the output directory
    # is shared between runs.
    with tf.Graph().as_default() as graph, self.test_session(graph):
      # The step-increment op is made to depend on whatever
      # `_build_inference_graph()` returns (presumably ops/tensors that must
      # run first -- confirm against the fixture).
      with tf.control_dependencies(self._build_inference_graph()):
        increment_step = tf.assign_add(
            tf.contrib.framework.get_global_step(), 1)
      learn.graph_actions.train(
          graph,
          output_dir=self._output_dir,
          train_op=increment_step,
          loss_op=tf.constant(2.0),
          max_steps=target_steps)
      global_step_name = tf.contrib.framework.get_global_step().name
      restored_step = tf.contrib.framework.load_variable(
          self._output_dir, global_step_name)
      self.assertEqual(target_steps, restored_step)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号