def prec_rf(n_trees, X_train, y_train, X_test, y_test):
    """Fit a random forest classifier and log its accuracy on a test set.

    Dense feature arrays are flattened to 2-D ``(n_samples, -1)`` before
    fitting/predicting; sparse matrices (per ``issparse``) are passed through
    unchanged.

    Args:
        n_trees: Number of trees in the forest (``n_estimators``).
        X_train: Training features; dense array with at least one dimension,
            or a sparse matrix.
        y_train: Training labels.
        X_test: Test features with the same feature layout as ``X_train``.
        y_test: Ground-truth labels for ``X_test`` (must expose ``.shape``,
            since it is logged). A column vector ``(n, 1)`` is accepted.

    Returns:
        tuple: ``(clf, y_pred)`` -- the fitted ``RandomForestClassifier`` and
        its predictions for ``X_test``.

    Raises:
        ValueError: if ``y_test`` does not contain the same number of label
            values as the predictions.
    """
    from sklearn.ensemble import RandomForestClassifier

    if not issparse(X_train):
        X_train = X_train.reshape((X_train.shape[0], -1))
    if not issparse(X_test):
        X_test = X_test.reshape((X_test.shape[0], -1))
    LOGGER.info('start predict: n_trees={},X_train.shape={},y_train.shape={},X_test.shape={},y_test.shape={}'.format(
        n_trees, X_train.shape, y_train.shape, X_test.shape, y_test.shape))

    clf = RandomForestClassifier(n_estimators=n_trees, max_depth=None, n_jobs=-1, verbose=1)
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)

    # Align the label array with the prediction array before comparing.
    # Without this, y_test of shape (n, 1) vs y_pred of shape (n,) would
    # broadcast to an (n, n) comparison and give a meaningless precision.
    y_true = np.asarray(y_test)
    if y_true.size != y_pred.size:
        raise ValueError(
            'y_test has {} values but {} predictions were made'.format(
                y_true.size, y_pred.size))
    y_true = y_true.reshape(y_pred.shape)

    # Fraction of predicted labels that match the ground truth.
    prec = float(np.mean(y_pred == y_true))
    LOGGER.info('prec_rf{}={:.6f}%'.format(n_trees, prec*100.0))
    return clf, y_pred
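# 评论列表 / 文章目录 ("comment list" / "table of contents"): navigation text
# left over from the web page this snippet was copied from.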