basic_session_run_hooks_test.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def test_stop_based_on_last_step(self):
    h = basic_session_run_hooks.StopAtStepHook(last_step=10)
    with tf.Graph().as_default():
      global_step = tf.contrib.framework.get_or_create_global_step()
      no_op = tf.no_op()
      h.begin()
      with tf.Session() as sess:
        mon_sess = monitored_session._HookedSession(sess, [h])
        sess.run(tf.assign(global_step, 5))
        mon_sess.run(no_op)
        self.assertFalse(mon_sess.should_stop())
        sess.run(tf.assign(global_step, 9))
        mon_sess.run(no_op)
        self.assertFalse(mon_sess.should_stop())
        sess.run(tf.assign(global_step, 10))
        mon_sess.run(no_op)
        self.assertTrue(mon_sess.should_stop())
        sess.run(tf.assign(global_step, 11))
        mon_sess._should_stop = False
        mon_sess.run(no_op)
        self.assertTrue(mon_sess.should_stop())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号