monitored_session_test.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def test_regular_exception_pass_through_run(self):
    # An ordinary exception raised inside a "with MonitoredSession" block
    # must propagate out to the caller unchanged, and the session must be
    # left in stop mode afterwards.
    with tf.Graph().as_default():
      global_step = tf.contrib.framework.get_or_create_global_step()
      increment_step = tf.assign_add(global_step, 1)
      error = RuntimeError('regular exception')
      raising_hook = RaiseOnceAtCountN(4, error)
      session = monitored_session.MonitoredSession(hooks=[raising_hook])
      # The outer context expects the RuntimeError to escape the session's
      # own context manager; contexts are entered left to right.
      with self.assertRaisesRegexp(RuntimeError, 'regular exception'), session:
        # The global step starts at zero before any increment.
        self.assertEqual(0, session.run(global_step))
        # Two increments go through without the hook interfering.
        for expected_step in (1, 2):
          self.assertEqual(expected_step, session.run(increment_step))
        self.assertFalse(session.should_stop())
        # This run is the one that trips the hook, which raises the error.
        session.run(increment_step)
        # Unreachable if the exception propagated as expected.
        self.assertFalse(True)
      # After the block: the hook fired and the session asks to stop.
      self.assertTrue(raising_hook.raised)
      self.assertTrue(session.should_stop())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号