def test_regular_exception_pass_through_run(self):
    # Tests that regular exceptions just pass through a "with
    # MonitoredSession" block and set the session in stop mode.
    with tf.Graph().as_default():
        gstep = tf.contrib.framework.get_or_create_global_step()
        do_step = tf.assign_add(gstep, 1)
        # NOTE(review): presumably the hook raises on the 4th run() call;
        # confirm against the RaiseOnceAtCountN definition.
        hook = RaiseOnceAtCountN(4, RuntimeError('regular exception'))
        session = monitored_session.MonitoredSession(hooks=[hook])
        # The RuntimeError must escape the "with session" block unchanged.
        with self.assertRaisesRegexp(RuntimeError, 'regular exception'):
            with session:
                self.assertEqual(0, session.run(gstep))
                self.assertEqual(1, session.run(do_step))
                self.assertEqual(2, session.run(do_step))
                self.assertFalse(session.should_stop())
                # This triggers the hook and raises the exception.
                session.run(do_step)
                # We should not hit this: the run() call above must raise.
                self.fail('session.run() did not raise the hook exception')
        # Checked outside the raising block so these assertions actually
        # execute once the exception has propagated.
        self.assertTrue(hook.raised)
        self.assertTrue(session.should_stop())
# Comment list
# Article table of contents