xg_train.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:trend_ml_toolkit_xgboost 作者: raymon-tian 项目源码 文件源码
def tune_num_boost_round():
    # global watchlist
    global num_boost_round
    global evals_result
    global eval_metric_xgb_format
    evals_result = {}
    xgb.train(params=params_no_sklearn,dtrain=dtrain,num_boost_round=num_boost_round,evals=watchlist,evals_result=evals_result)
    evals_result = evals_result['eval'][eval_metric_xgb_format]
    # pprint.pprint(evals_result)
    max = 0.0
    max_loc = 0
    for i,v in enumerate(evals_result):
        # print '%d ...... %d : %d'%(i,max_loc,max)
        if v>max:
            max = v
            max_loc = i
    # print "max_loc : %s ,  max : %s"%(max_loc,max)
    num_boost_round = max_loc+1
    print('****  num_boost_round : ', num_boost_round)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号