def feature_importance(rf, save_path=None):
    """Plot the feature importances of an sklearn random forest.

    Draws one red bar per feature, ordered from the most to the least
    important one. The x tick labels are the original feature indices.
    The error bars show the standard error of the importance across the
    individual trees, i.e. the standard deviation over
    ``rf.estimators_`` divided by the square root of the number of trees.

    The figure is always closed before returning (also when plotting or
    saving fails), so the plot is only kept if `save_path` is given.

    Parameters
    ----------
    rf : RandomForestClassifier
        Forest exposing ``feature_importances_`` and ``estimators_``
        (i.e. it has to be fitted already).
    save_path : str
        Destination handed to ``pl.savefig``. If None (default), nothing
        is written.
    """
    importances = rf.feature_importances_
    nb = len(importances)
    tree_imp = [tree.feature_importances_ for tree in rf.estimators_]
    # Standard error of the mean importance over all trees.
    std = np.std(tree_imp, axis=0) / np.sqrt(len(tree_imp))
    # Feature indices sorted by decreasing importance.
    indices = np.argsort(importances)[::-1]

    positions = range(nb)
    fig = pl.figure()
    try:
        pl.title("Feature importances")
        pl.bar(positions, importances[indices],
               color="r", yerr=std[indices], align="center")
        pl.xticks(positions, indices)
        # Leave half a bar of margin on both sides.
        pl.xlim([-1, nb])
        if save_path is not None:
            pl.savefig(save_path)
    finally:
        # Close exactly the figure created above, even on errors, so that
        # no figure is leaked in the plotting module's global state.
        # NOTE(review): the closing call is treated as unconditional (not
        # nested under the save_path check) - confirm against callers.
        pl.close(fig)
评论列表
文章目录