def test_retry_on_aborted_error(self):
    """MonitoredSession should transparently retry after an AbortedError.

    Recovery is not exercised here: no CheckpointSaver is installed, so the
    retried session starts from freshly initialized variables and the global
    step counts up from 0 again.
    """
    with tf.Graph().as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()
        increment = tf.assign_add(global_step, 1)
        abort_error = tf.errors.AbortedError(None, None, 'Abort')
        # NOTE(review): presumably RaiseOnceAtCountN counts run() calls and
        # raises `abort_error` exactly once, on the 4th call -- confirm
        # against its definition.
        abort_hook = RaiseOnceAtCountN(4, abort_error)
        with monitored_session.MonitoredSession(hooks=[abort_hook]) as sess:
            # Normal operation: the step starts at 0 and advances by one.
            self.assertEqual(0, sess.run(global_step))
            self.assertEqual(1, sess.run(increment))
            self.assertEqual(2, sess.run(increment))
            self.assertFalse(sess.should_stop())

            # At step 3 the hook raises AbortedError. MonitoredSession
            # retries automatically and restarts from a freshly initialized
            # session, so the step is back to 0 and this increment yields 1.
            self.assertEqual(1, sess.run(increment))
            self.assertFalse(sess.should_stop())
            self.assertTrue(abort_hook.raised)

            # Work continues normally on the retried session.
            self.assertEqual(2, sess.run(increment))
            self.assertFalse(sess.should_stop())