test_io_types.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:coremltools 作者: apple 项目源码 文件源码
def test_random_forest_regressor(self):
        for dtype in self.number_data_type.keys():
            scikit_model = RandomForestRegressor(random_state=1)
            data = self.scikit_data['data'].astype(dtype)
            target = self.scikit_data['target'].astype(dtype)
            scikit_model, spec = self._sklearn_setup(scikit_model, dtype, data, target)
            test_data = data[0].reshape(1, -1)
            self._check_tree_model(spec, 'multiArrayType', 'doubleType', 1)
            coreml_model = create_model(spec)
            try:
                self.assertEqual(scikit_model.predict(test_data)[0].dtype,
                                 type(coreml_model.predict({'data': test_data})['target']))
                self.assertAlmostEqual(scikit_model.predict(test_data)[0],
                                       coreml_model.predict({'data': test_data})['target'],
                                       msg="{} != {} for Dtype: {}".format(
                                           scikit_model.predict(test_data)[0],
                                           coreml_model.predict({'data': test_data})['target'],
                                           dtype
                                       )
                                       )
            except RuntimeError:
                print("{} not supported. ".format(dtype))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号