model_gdbt.py source file

python
Project: JData  Author: Xls1994
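import pandas as pd
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import precision_recall_fscore_support


def GBDT_classify(train_dataSet_path, test_dataSet_path, train_one_and_two_result_as_proba_path):
    """Train a GBDT classifier, print test metrics, and write test-set probabilities to a file."""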

    train_data = pd.read_csv(train_dataSet_path)
    train_data = train_data.to_numpy()  # as_matrix() was removed in pandas 1.0
    X_train = train_data[:, 2:-1]  # features: skip the first two columns and the last one
    y_train = train_data[:, -1]  # label: last column

    test_data = pd.read_csv(test_dataSet_path)
    test_data = test_data.to_numpy()
    X_test = test_data[:, 2:-1]  # features: skip the first two columns and the last one
    y_test = test_data[:, -1]  # label: last column

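    clf = GradientBoostingClassifier(n_estimators=200)
    clf.fit(X_train, y_train)
    pre_y_test = clf.predict_proba(X_test)  # shape (n_samples, n_classes)
    print(pre_y_test)
    # The metric needs hard class labels, not probabilities.
    pre_y_label = clf.predict(X_test)
    print("GBDT Metrics : {0}".format(precision_recall_fscore_support(y_test, pre_y_label)))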

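    print("Writing predicted probabilities...")
    # Column 0 of predict_proba is the probability of clf.classes_[0].
    with open(train_one_and_two_result_as_proba_path, 'w') as f_result:
        for i in range(len(pre_y_test)):
            # Echo the first and last values as a quick sanity check.
            if i == 0:
                print(str(pre_y_test[i][0]))
            if i == len(pre_y_test) - 1:
                print(str(pre_y_test[i][0]))
            f_result.write(str(pre_y_test[i][0]) + '\n')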

    return clf
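
Note that `precision_recall_fscore_support` expects class labels, whereas `predict_proba` returns one probability per class. The following is a minimal sketch, not part of the original project, of how probabilities could be turned into labels and scored by hand. It assumes a binary task with labels 0 and 1 and a 0.5 threshold; the helper name `binary_prf` and the sample values are hypothetical.

```python
def binary_prf(y_true, proba_pos, threshold=0.5):
    """Precision, recall and F1 for the positive class (label 1)."""
    # Turn each positive-class probability into a hard 0/1 label.
    y_pred = [1 if p >= threshold else 0 for p in proba_pos]
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    # Guard against empty denominators.
    precision = tp / float(tp + fp) if tp + fp else 0.0
    recall = tp / float(tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return precision, recall, f1


# Hypothetical labels and positive-class probabilities, for illustration only.
y_true = [1, 0, 1, 1, 0]
proba_pos = [0.9, 0.6, 0.4, 0.8, 0.1]
precision, recall, f1 = binary_prf(y_true, proba_pos)
print("precision={0:.3f} recall={1:.3f} f1={2:.3f}".format(precision, recall, f1))
```

With a fitted binary classifier, the positive-class probabilities would typically be the second column of the `predict_proba` output, matching the order of `clf.classes_`.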