rand_forest.py source code

python
Project: taxi  Author: xuguanggen
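# Assumed imports: the excerpt omits the file's import section.
from sklearn.ensemble import RandomForestRegressor
import joblib

# train_csv_path, test_csv_path, Model_Name, fields, load_data and save_results
# are expected to be defined at module level in the original file (not shown here).
def run(result_csv_path):
    # Load training features/targets and the unlabeled test features.
    train_x,train_y = load_data(train_csv_path,True)
    test_x = load_data(test_csv_path,False)
    print('load data successfully ......')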

    rf = RandomForestRegressor(
            n_estimators = 2000, #[1500,2000]
            min_samples_split = 2,
            max_depth = 15, # [10,15]
            n_jobs = -1
            )
    rf.fit(train_x,train_y)
    ###### save model ##################
    joblib.dump(rf,'weights/'+Model_Name+'.m')

    y_pred = rf.predict(test_x)


    ####### save_results ###########################
    save_results(result_csv_path,y_pred)

    ###### generate report #######################
    feature_importances = rf.feature_importances_
    dic_feature_importances = dict(zip(fields,feature_importances))
    # dict.items() replaces the Python 2-only iteritems(); sort by importance, descending.
    dic = sorted(dic_feature_importances.items(),key = lambda d:d[1],reverse = True)
    print('feature_importances:')
    for name, importance in dic:
        print(name+":\t"+str(importance))
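
The report step at the end of `run` pairs each field name with its importance and prints the pairs from most to least important. A minimal standalone sketch of that ranking is shown below. The helper name `rank_features` and the sample field names and values are hypothetical, chosen only for illustration; they are not taken from the taxi project.

```python
def rank_features(names, importances):
    """Return (name, importance) pairs sorted from most to least important."""
    if len(names) != len(importances):
        raise ValueError("names and importances must have the same length")
    return sorted(zip(names, importances), key=lambda pair: pair[1], reverse=True)


# Hypothetical field names and importance values, for illustration only.
example_fields = ["hour", "distance", "weekday"]
example_importances = [0.2, 0.7, 0.1]

for name, score in rank_features(example_fields, example_importances):
    print(f"{name}:\t{score}")
```

Sorting the zipped pairs directly avoids the intermediate dictionary, which would silently drop entries if two fields happened to share a name.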