def test_unsupported(self):
    """Exercise StopOnNaN on a value whose type it cannot check for NaNs.

    A lambda stands in for the unsupported value. The scenarios are:

    * ``on_unknown_type='error'``: ``after_epoch`` must raise ``ValueError``.
    * ``on_unknown_type='bad value'``: the constructor must raise
      ``AssertionError``.
    * ``on_unknown_type='warn'``: ``after_epoch`` must return normally, and
      the captured log must match exactly the single root WARNING below.
    * Default construction: ``after_epoch`` must complete without raising.
    """
    def unsupported_epoch_data():
        # Build epoch data around a brand-new lambda on every call, exactly
        # like the inline ``lambda: 0`` arguments this helper stands in for.
        return StopOnNaNTest._get_data(lambda: 0)

    expected_warning = (
        'root',
        'WARNING',
        'Variable `var` of type `<class \'function\'>` can not be checked for NaNs.',
    )

    # 'error' mode: encountering the unsupported value must raise.
    with self.assertRaises(ValueError):
        strict_hook = StopOnNaN(on_unknown_type='error')
        strict_hook.after_epoch(epoch_data=unsupported_epoch_data())

    # An unrecognized mode string is refused when the hook is constructed.
    with self.assertRaises(AssertionError):
        StopOnNaN(on_unknown_type='bad value')

    # 'warn' mode: the call returns normally, and the log records a warning.
    with LogCapture() as captured:
        warning_hook = StopOnNaN(on_unknown_type='warn')
        warning_hook.after_epoch(epoch_data=unsupported_epoch_data())
    captured.check(expected_warning)

    # Default construction: the test only requires that this does not raise.
    StopOnNaN().after_epoch(epoch_data=unsupported_epoch_data())
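# NOTE(review): the two lines below are scraped web-page residue, not code.
# 评论列表 ("Comment list")
# 文章目录 ("Table of contents")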