train.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:JData 作者: edvardHua 项目源码 文件源码
def xgboost_cv(threshold=0.04):
    """Train an XGBoost classifier on one time window and evaluate it on a later one.

    Despite the name, this is a hold-out evaluation, not k-fold
    cross-validation:

    1. A training set is built with ``make_train_set`` from the
       2016-02-15..2016-03-15 / 2016-03-16..2016-03-20 windows. 20% of its
       rows are held out as the ``eval`` set monitored during boosting.
    2. The trained booster scores a second set built from the later
       2016-03-21..2016-04-02 / 2016-04-03..2016-04-08 windows.
    3. Rows whose predicted probability is ``>= threshold`` are passed to
       ``report`` together with the rows whose true label is 1.

    Args:
        threshold: Predicted probability at or above which a row of the
            evaluation set is treated as a positive prediction. Defaults to
            0.04.

    Returns:
        None. Output is produced by xgboost's training log and by ``report``.
    """
    # Training set: the first window pair. NOTE(review): make_train_set
    # presumably builds features from [start, end] and labels from
    # [test_start, test_end]; confirm against its definition.
    train_start_date = '2016-02-15'
    train_end_date = '2016-03-15'
    test_start_date = '2016-03-16'
    test_end_date = '2016-03-20'

    # Evaluation set: a later window pair that does not overlap the one above.
    sub_start_date = '2016-03-21'
    sub_end_date = '2016-04-02'
    sub_test_start_date = '2016-04-03'
    sub_test_end_date = '2016-04-08'

    user_index, training_data, label = make_train_set(
        train_start_date, train_end_date, test_start_date, test_end_date)

    # Hold out 20% of the training rows (fixed seed) as the 'eval' set whose
    # AUC is printed each boosting round.
    X_train, X_test, y_train, y_test = train_test_split(
        training_data, label, test_size=0.2, random_state=0)
    dtrain = xgb.DMatrix(X_train.values, label=y_train)
    dtest = xgb.DMatrix(X_test.values, label=y_test)

    param = {
        'max_depth': 10,
        'eta': 0.05,
        'silent': 1,
        'objective': 'binary:logistic',
        'nthread': 5,
        'eval_metric': 'auc',
    }
    num_round = 166
    evallist = [(dtest, 'eval'), (dtrain, 'train')]
    # Pass the dict itself. xgb.train documents `params` as a dict, while
    # dict.items() on Python 3 returns a view that recent xgboost releases
    # reject.
    bst = xgb.train(param, dtrain, num_round, evallist)

    # Score the later-window evaluation set with the trained booster.
    sub_user_index, sub_training_data, sub_label = make_train_set(
        sub_start_date, sub_end_date, sub_test_start_date, sub_test_end_date)
    y = bst.predict(xgb.DMatrix(sub_training_data.values))

    pred = sub_user_index.copy()
    y_true = sub_user_index.copy()
    pred['label'] = y
    # Ground truth must come from the same call that produced sub_user_index.
    # The training-window `label` belongs to the rows of user_index, not
    # sub_user_index.
    y_true['label'] = sub_label

    pred = pred[pred['label'] >= threshold]
    y_true = y_true[y_true['label'] == 1]

    report(pred, y_true)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号