write_tensorboard_test.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:cxflow-tensorflow 作者: Cognexa 项目源码 文件源码
def test_unknown_type(self):
        """Test if ``WriteTensorBoard`` handles unknown variable types as expected.

        A string value is passed where a numeric one is expected and the hook is
        exercised in each configuration used below:

        - default (no ``on_unknown_type`` given): no log record is expected,
        - ``on_unknown_type='warn'``: exactly one warning naming the variable, the stream
          and the offending type is expected,
        - ``on_unknown_type='error'``: a ``ValueError`` is expected,
        - ``on_unknown_type='error'`` with the variable listed in ``image_variables``:
          an array value must be accepted without raising.

        Every hook's summary writer is closed as soon as its scenario is finished (even
        when the scenario fails), so no writer is left open on ``self.tmpdir``.
        """
        # Function-local import keeps this method self-contained.
        from contextlib import closing

        bad_epoch_data = {'valid': {'accuracy': 'bad_type'}}

        # test ignore
        hook = WriteTensorBoard(output_dir=self.tmpdir, model=self.get_model())
        # Context managers exit right-to-left: the log capture ends first, then the writer is closed.
        with closing(hook._summary_writer), LogCapture(level=logging.INFO) as log_capture:
            hook.after_epoch(42, bad_epoch_data)
        log_capture.check()

        # test warn
        warn_hook = WriteTensorBoard(output_dir=self.tmpdir, model=self.get_model(), on_unknown_type='warn')
        with closing(warn_hook._summary_writer), LogCapture(level=logging.INFO) as log_capture2:
            warn_hook.after_epoch(42, bad_epoch_data)
        log_capture2.check(('root', 'WARNING', 'Variable `accuracy` in stream `valid` has to be of type `int` '
                                               'or `float` (or a `dict` with a key named `mean` or `nanmean` '
                                               'whose corresponding value is of type `int` or `float`), '
                                               'found `<class \'str\'>` instead.'))

        # test error
        raise_hook = WriteTensorBoard(output_dir=self.tmpdir, model=self.get_model(), on_unknown_type='error')
        # ``assertRaises`` consumes the expected ``ValueError`` before ``closing`` closes the writer.
        with closing(raise_hook._summary_writer), self.assertRaises(ValueError):
            raise_hook.after_epoch(42, bad_epoch_data)

        with mock.patch.dict('sys.modules', **{'cv2': cv2_mock}):
            # test skip image variables
            skip_hook = WriteTensorBoard(output_dir=self.tmpdir, model=self.get_model(), on_unknown_type='error',
                                         image_variables=['accuracy'])
            # The writer is still closed while ``cv2`` is patched, as before, but now also on failure.
            with closing(skip_hook._summary_writer):
                skip_hook.after_epoch(42, {'valid': {'accuracy': np.zeros((10, 10, 3))}})
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号