def test_serialization():
    """Check model serialization.

    Fits a ``GradientBoostingClassifier``, pickles it with the highest
    available protocol, restores it, and verifies that the restored model
    gives the same predictions on ``T`` and holds the same number of
    estimators as before the round trip.
    """
    # cPickle only exists as a separate module on Python 2; fall back to the
    # standard ``pickle`` module everywhere else.
    try:
        import cPickle as pickle
    except ImportError:
        import pickle

    n_estimators = 100
    clf = GradientBoostingClassifier(n_estimators=n_estimators,
                                     random_state=1)
    clf.fit(X, y)

    # Baseline: the freshly fitted model behaves as expected.
    assert_array_equal(clf.predict(T), true_result)
    assert_equal(n_estimators, len(clf.estimators_))

    serialized_clf = pickle.dumps(clf, protocol=pickle.HIGHEST_PROTOCOL)

    # Drop the reference to the fitted model so the checks below can only
    # succeed through the unpickled copy.
    clf = None
    clf = pickle.loads(serialized_clf)

    # The restored model must match the baseline exactly.
    assert_array_equal(clf.predict(T), true_result)
    assert_equal(n_estimators, len(clf.estimators_))
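# Non-code page residue (navigation label): "评论列表" (comment list)
# Non-code page residue (navigation label): "文章目录" (table of contents)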