def check_svm_model_equal(dense_svm, sparse_svm, X_train, y_train, X_test):
dense_svm.fit(X_train.toarray(), y_train)
if sparse.isspmatrix(X_test):
X_test_dense = X_test.toarray()
else:
X_test_dense = X_test
sparse_svm.fit(X_train, y_train)
assert_true(sparse.issparse(sparse_svm.support_vectors_))
assert_true(sparse.issparse(sparse_svm.dual_coef_))
assert_array_almost_equal(dense_svm.support_vectors_,
sparse_svm.support_vectors_.toarray())
assert_array_almost_equal(dense_svm.dual_coef_, sparse_svm.dual_coef_.toarray())
if dense_svm.kernel == "linear":
assert_true(sparse.issparse(sparse_svm.coef_))
assert_array_almost_equal(dense_svm.coef_, sparse_svm.coef_.toarray())
assert_array_almost_equal(dense_svm.support_, sparse_svm.support_)
assert_array_almost_equal(dense_svm.predict(X_test_dense), sparse_svm.predict(X_test))
assert_array_almost_equal(dense_svm.decision_function(X_test_dense),
sparse_svm.decision_function(X_test))
assert_array_almost_equal(dense_svm.decision_function(X_test_dense),
sparse_svm.decision_function(X_test_dense))
if isinstance(dense_svm, svm.OneClassSVM):
msg = "cannot use sparse input in 'OneClassSVM' trained on dense data"
else:
assert_array_almost_equal(dense_svm.predict_proba(X_test_dense),
sparse_svm.predict_proba(X_test), 4)
msg = "cannot use sparse input in 'SVC' trained on dense data"
if sparse.isspmatrix(X_test):
assert_raise_message(ValueError, msg, dense_svm.predict, X_test)
评论列表
文章目录