def test_stop_based_on_last_step(self):
    """Check StopAtStepHook(last_step=10) against several global-step values.

    The global step is forced to 5, 9, 10 and 11 in turn, with one hooked
    run after each. The test expects no stop at 5 and 9, and a stop at 10
    and 11. Before the step-11 run, the hooked session's private
    ``_should_stop`` flag is cleared by hand. Presumably the flag would
    otherwise still be set from the step-10 run, so the reset ensures the
    final assertion reflects only the step-11 run.
    """
    stop_hook = basic_session_run_hooks.StopAtStepHook(last_step=10)
    with tf.Graph().as_default():
        step_var = tf.contrib.framework.get_or_create_global_step()
        noop = tf.no_op()
        stop_hook.begin()
        with tf.Session() as raw_sess:
            hooked_sess = monitored_session._HookedSession(raw_sess, [stop_hook])
            # Each case: (value forced into the global step,
            #             clear the hooked session's stop flag before running?,
            #             stop expected after the run?)
            cases = (
                (5, False, False),
                (9, False, False),
                (10, False, True),
                (11, True, True),
            )
            for step_value, clear_flag, expect_stop in cases:
                # Set the step directly on the raw session so the hook only
                # observes it during the hooked run below.
                raw_sess.run(tf.assign(step_var, step_value))
                if clear_flag:
                    hooked_sess._should_stop = False
                hooked_sess.run(noop)
                check = self.assertTrue if expect_stop else self.assertFalse
                check(hooked_sess.should_stop())
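# Navigation labels left over from the source web page (not part of the code):
# 评论列表 ("Comment list")
# 文章目录 ("Table of contents")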