def test_unknown_type(self):
    """Test that ``WriteTensorBoard`` handles variables of unsupported types as configured.

    Every variable below is fed through ``after_epoch`` under a different
    ``on_unknown_type`` setting:

    * default setting (expected to ignore): no log records at INFO or above;
    * ``'warn'``: exactly one WARNING on the root logger describing the variable;
    * ``'error'``: ``ValueError`` is raised;
    * ``'error'`` with the variable listed in ``image_variables``: no exception is
      raised for an array value (run with ``cv2`` replaced by ``cv2_mock`` in
      ``sys.modules``).

    The summary writer of every hook created here is closed at the end, even
    when an assertion fails, so no writer is left open after the test.
    """
    bad_epoch_data = {'valid': {'accuracy': 'bad_type'}}
    created_hooks = []

    def make_hook(**kwargs):
        """Create a ``WriteTensorBoard`` hook in the test dir and remember it for cleanup."""
        new_hook = WriteTensorBoard(output_dir=self.tmpdir, model=self.get_model(), **kwargs)
        created_hooks.append(new_hook)
        return new_hook

    try:
        # test ignore (default on_unknown_type)
        hook = make_hook()
        with LogCapture(level=logging.INFO) as log_capture:
            hook.after_epoch(42, bad_epoch_data)
        log_capture.check()  # expect no log records at all

        # test warn
        warn_hook = make_hook(on_unknown_type='warn')
        with LogCapture(level=logging.INFO) as log_capture2:
            warn_hook.after_epoch(42, bad_epoch_data)
        log_capture2.check(('root', 'WARNING', 'Variable `accuracy` in stream `valid` has to be of type `int` '
                                               'or `float` (or a `dict` with a key named `mean` or `nanmean` '
                                               'whose corresponding value is of type `int` or `float`), '
                                               'found `<class \'str\'>` instead.'))

        # test error
        raise_hook = make_hook(on_unknown_type='error')
        with self.assertRaises(ValueError):
            raise_hook.after_epoch(42, bad_epoch_data)

        with mock.patch.dict('sys.modules', **{'cv2': cv2_mock}):
            # test skip image variables: must not raise despite on_unknown_type='error'
            skip_hook = make_hook(on_unknown_type='error', image_variables=['accuracy'])
            skip_hook.after_epoch(42, {'valid': {'accuracy': np.zeros((10, 10, 3))}})
    finally:
        # Close every writer (not just the last hook's) so none stays open after the test.
        for created in created_hooks:
            writer = getattr(created, '_summary_writer', None)
            if writer is not None:
                writer.close()
write_tensorboard_test.py 文件源码
python
阅读 31
收藏 0
点赞 0
评论 0
评论列表
文章目录