classifier_utils.py source code

python

Project: human-rl  Author: gsastry
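import numpy as np
from sklearn.ensemble import RandomForestClassifier

# X_train, Y_train, X_test, Y_test and zero_one_score are not defined in this
# excerpt; they are assumed to come from elsewhere in classifier_utils.py.

def run_forests():
    """Randomly sample max_features and max_depth for a 50-tree random forest."""
    print('random forest: \n')
    params = []
    scores = []

    for _ in range(5):
        # Sample one configuration; None (unlimited depth) is drawn 4 times out of 7.
        max_features = np.random.randint(400, 800)
        max_depth = np.random.choice([None, None, None, None, 30, 40, 60])
        forest = RandomForestClassifier(n_estimators=50,
                                        max_features=max_features,
                                        max_depth=max_depth)
        forest_fit = forest.fit(X_train, Y_train)
        pred = forest_fit.predict(X_test)
        test_score = zero_one_score(Y_test, pred)
        print('\n params:', dict(max_features=max_features, max_depth=max_depth))
        print('forest train: ', zero_one_score(Y_train, forest_fit.predict(X_train)),
              ' test: ', test_score)

        params.append((max_features, max_depth))
        scores.append(test_score)

    # argmin assumes zero_one_score is an error rate (lower is better);
    # if it returns accuracy instead, np.argmax is the right choice here.
    print('best:', params[np.argmin(scores)])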
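
The selection logic in run_forests() can be separated from the model itself. The block below is a minimal sketch of that pattern using only the standard library. It assumes the score being minimised is a zero-one error rate (the fraction of misclassified samples). The names `zero_one_error`, `random_search`, `sample_forest_params` and `toy_error` are hypothetical helpers introduced for illustration and are not part of the human-rl project. `toy_error` is a placeholder objective, not a trained model.

```python
import random


def zero_one_error(y_true, y_pred):
    """Fraction of misclassified samples (0.0 is perfect, lower is better)."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    mistakes = sum(1 for t, p in zip(y_true, y_pred) if t != p)
    return mistakes / len(y_true)


def random_search(sample_params, evaluate, n_trials=5, seed=None):
    """Try n_trials random configurations and keep the one with the lowest error.

    sample_params(rng) -> dict of hyperparameters
    evaluate(params)   -> error to minimise
    """
    rng = random.Random(seed)
    history = []
    for _ in range(n_trials):
        params = sample_params(rng)
        history.append((params, evaluate(params)))
    best_params, best_error = min(history, key=lambda item: item[1])
    return best_params, best_error, history


def sample_forest_params(rng):
    # Same ranges as run_forests(): None appears four times in the list, so an
    # unlimited depth is drawn with probability 4/7.
    return {
        "max_features": rng.randint(400, 799),  # inclusive on both ends
        "max_depth": rng.choice([None, None, None, None, 30, 40, 60]),
    }


def toy_error(params):
    # Placeholder objective, NOT a real model: pretends 600 features is ideal.
    return abs(params["max_features"] - 600) / 1000


# Two of the four labels differ, so the error rate is 0.5.
print(zero_one_error([1, 0, 1, 1], [1, 1, 1, 0]))

best_params, best_error, history = random_search(
    sample_forest_params, toy_error, n_trials=5, seed=0)
print("best:", best_params, "error:", best_error)
```

To use this with scikit-learn, `evaluate` would fit `RandomForestClassifier(n_estimators=50, **params)` on the training data and return `sklearn.metrics.zero_one_loss(Y_test, forest.predict(X_test))`, which by default is the fraction of misclassified samples. Writing the objective as an explicit error makes the direction of the final `min` unambiguous.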