def test_model_tranining(self):
    """Train an SVM through ``Pipeline`` and check that it yields predictions.

    Loads the digits dataset via ``datasets.load_digits()`` and flattens each
    image into one feature row. The first half of the samples trains the
    pipeline; the second half is passed to ``predict``.
    """
    # test by running svm on digits
    digits = datasets.load_digits()
    n_samples = len(digits.images)
    # Flatten each image into a single row -> 2-D array (n_samples, pixels).
    data = digits.images.reshape((n_samples, -1))
    # Index splitting the samples into a train half and a held-out half.
    split = n_samples // 2

    svm = SVM()
    pipe = Pipeline(models={'SVM': svm})
    pipe.train(data[:split], digits.target[:split])
    # Training through the pipeline should populate the wrapped model.
    assert svm.classifier is not None

    expected = digits.target[split:]
    predicted = pipe.predict(data[split:])
    # Predictions are looked up by the same key used in ``models`` above.
    assert predicted['SVM'] is not None
    # One prediction per held-out sample.
    # NOTE(review): assumes predict() returns a sized sequence per model
    # (e.g. an array of labels) - confirm against Pipeline.predict.
    assert len(predicted['SVM']) == len(expected)
评论列表
文章目录