train.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:JData 作者: Xls1994 项目源码 文件源码
def xgboost_make_submission(retrain=False, threshold=0.03):
    """Score the submission window and write a submission CSV.

    Loads the cached booster from ``./cache/bstmodel.bin`` when that file
    exists and ``retrain`` is false; otherwise obtains a model from
    ``xgboost_train()``. The test set for 2016-03-15 .. 2016-04-16 is built
    with ``make_test_set``, scored, and rows whose score is >= ``threshold``
    are kept. Only the first qualifying row per ``user_id`` is written.

    Args:
        retrain: if True, ignore any cached model and call ``xgboost_train()``.
        threshold: minimum predicted score for a (user_id, sku_id) row to be
            kept. Defaults to 0.03, the value previously hard-coded here.

    Returns:
        None. Writes ``./sub/submission_<YYYY-MM-DD_HHMMSS>.csv`` with the
        columns ``user_id`` and ``sku_id``.
    """
    sub_start_date = '2016-03-15'
    sub_end_date = '2016-04-16'
    model_path = './cache/bstmodel.bin'

    if os.path.exists(model_path) and not retrain:
        # The key must be 'nthread'. The previous 'ntheard' typo is not a
        # recognized XGBoost parameter, so the thread count was never applied.
        bst = xgb.Booster({'nthread': 4})
        bst.load_model(model_path)
    else:
        bst = xgboost_train()

    sub_user_index, sub_trainning_data = make_test_set(sub_start_date, sub_end_date)
    # Features are passed positionally via .values. NOTE(review): this assumes
    # make_test_set yields columns in the same order used for training.
    sub_trainning_data = xgb.DMatrix(sub_trainning_data.values)
    sub_user_index['label'] = bst.predict(sub_trainning_data)

    pred = sub_user_index[sub_user_index['label'] >= threshold]
    pred = pred[['user_id', 'sku_id']]
    # One sku per user: groupby().first() takes the first qualifying row of
    # each user in the order produced by make_test_set.
    pred = pred.groupby('user_id').first().reset_index()
    pred['user_id'] = pred['user_id'].astype(int)

    # Zero-padded timestamp. The old str(hour)+str(minute)+str(second) was
    # ambiguous (e.g. 1:12:03 and 11:02:03 both produced "1123"), which could
    # make a new submission overwrite an earlier one.
    sdt = datetime.datetime.now().strftime('%Y-%m-%d_%H%M%S')

    out_dir = './sub'
    if not os.path.isdir(out_dir):
        # Avoid a write failure on a fresh checkout where ./sub is missing.
        os.makedirs(out_dir)
    pred.to_csv(os.path.join(out_dir, 'submission_%s.csv' % sdt), index=False)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号