def test_support_vector_classifier(self):
    """Check that an RBF-kernel SVC predicts the same labels in scikit-learn and Core ML.

    For every dtype in ``self.number_data_type``:

    1. The dataset is cast to that dtype.
    2. The target is binarized against its mean, so the classifier is trained
       on boolean labels.
    3. The model is set up via ``self._sklearn_setup`` and converted with
       ``create_model``.
    4. Up to the first ten samples are predicted by both models and compared.

    A ``RuntimeError`` raised by the Core ML prediction is reported as
    "dtype not supported". The message is printed once and the remaining
    samples for that dtype are skipped.
    """
    for dtype in self.number_data_type:
        scikit_model = SVC(kernel='rbf', gamma=1.2, C=1)
        data = self.scikit_data['data'].astype(dtype)
        # Binarize the target around its mean to obtain a two-class problem.
        raw_target = self.scikit_data['target'].astype(dtype)
        target = raw_target > raw_target.mean()
        scikit_model, spec = self._sklearn_setup(scikit_model, dtype, data, target)
        coreml_model = create_model(spec)
        # Compare at most the first ten samples (fewer if the dataset is smaller).
        for idx in range(min(10, len(data))):
            test_data = data[idx].reshape(1, -1)
            try:
                coreml_output = coreml_model.predict({'data': test_data})['target']
            except RuntimeError:
                # Print once per dtype, not once per sample.
                print("{} not supported. ".format(dtype))
                break
            # Predict once per model and reuse the values in the failure message.
            expected = scikit_model.predict(test_data)[0]
            # The Core ML output is converted via int() then bool() to match
            # the boolean labels the scikit-learn model was trained on.
            actual = bool(int(coreml_output))
            self.assertEqual(
                expected,
                actual,
                msg="{} != {} for Dtype: {}".format(expected, actual, dtype),
            )
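# NOTE(review): stray scraped web-page navigation text ("comment list" /
# "table of contents") was found here; it is not part of the test code.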