xgb_rank.py source file

python

Project: kaggle-review  Author: daxiongshu
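import xgboost as xgb  # needed by the method below

def fit(self, X, y, Xg, Xt=None, yt=None, Xgt=None, load_model=None, save_model=None):
    # Method of a ranker class; self.params is expected to hold 'num_round' and
    # 'early_stopping_rounds' alongside the booster parameters.
    # Xg / Xgt: group sizes (rows per query) for the training / validation rows.
    print(X.shape, y.shape)
    num_round = self.params['num_round']
    early_stopping_rounds = self.params['early_stopping_rounds']
    dtrain = xgb.DMatrix(X, label=y)
    dtrain.set_group(Xg)

    if Xt is not None:
        # A validation set is available: monitor it and stop early.
        dvalid = xgb.DMatrix(Xt, label=yt)
        dvalid.set_group(Xgt)
        watchlist = [(dtrain, 'train'), (dvalid, 'valid')]
        # maximize=True treats the validation metric as higher-is-better,
        # which fits ranking metrics such as NDCG or MAP.
        bst = xgb.train(self.params, dtrain, num_round, evals=watchlist,
                        early_stopping_rounds=early_stopping_rounds,
                        verbose_eval=1, xgb_model=load_model, maximize=True)
    else:
        # No validation set: train for the full num_round rounds.
        watchlist = [(dtrain, 'train')]
        bst = xgb.train(self.params, dtrain, num_round, evals=watchlist,
                        verbose_eval=1, xgb_model=load_model)
    self.bst = bst
    if save_model is not None:
        bst.save_model(save_model)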
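The method passes Xg and Xgt straight to set_group, which expects the number of rows in each query group, in row order, with the rows of one query stored contiguously. The snippet does not show how those sizes are built. The following is only a minimal sketch of one way to derive them, assuming a per-row list of query ids is available; the helper name group_sizes and the sample ids are hypothetical and do not come from the kaggle-review project.

```python
from itertools import groupby


def group_sizes(query_ids):
    """Return the length of each run of equal query ids, in row order.

    Assumes rows belonging to the same query are already contiguous.
    """
    return [sum(1 for _ in run) for _, run in groupby(query_ids)]


# Hypothetical query ids, one per training row, already sorted by query.
qid = [101, 101, 101, 205, 205, 309, 309, 309, 309]
Xg = group_sizes(qid)
print(Xg)  # [3, 2, 4]

# The sizes must add up to the number of rows in the feature matrix.
assert sum(Xg) == len(qid)
```

A list built this way could then be passed as the Xg argument of fit. If the rows are not sorted by query, the features, labels and query ids would need to be reordered together first, since groupby only merges adjacent equal ids.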