def fit(self,
        train_fs, train_labels,
        valid_fs, valid_labels):
    """Train an XGBoost ranking model and return its predictions on the validation set.

    Every query group is assumed to contain exactly ``rank_k`` rows, where
    ``rank_k`` is read from the ``[RANK]`` section of ``self.config``. The
    rows of each data set are therefore split into consecutive groups of
    ``rank_k`` rows each.

    Training hyper-parameters are read from ``self.params``; the keys
    ``num_round``, ``early_stop`` and ``verbose_eval`` are consumed here and
    the whole dict is passed to ``xgb.train``.

    Args:
        train_fs: Training features (anything ``xgb.DMatrix`` accepts).
        train_labels: Training labels, one per training row.
        valid_fs: Validation features.
        valid_labels: Validation labels, one per validation row.

    Returns:
        The model's predictions on the validation set, restricted to the
        ``best_ntree_limit`` trees found by early stopping.

    Raises:
        ValueError: If ``rank_k`` is not positive, or if the number of labels
            in either set is not a multiple of ``rank_k``.

    Side effects:
        Stores the trained booster in ``self.model``.
    """
    rank_k = self.config.getint('RANK', 'rank_k')
    if rank_k <= 0:
        raise ValueError('rank_k must be positive, got %d' % rank_k)

    def build_grouped_dmatrix(features, labels, set_name):
        # Build a DMatrix whose rows form consecutive query groups of rank_k rows.
        n_rows = len(labels)
        if n_rows % rank_k != 0:
            # Otherwise the group sizes would not sum to the row count and
            # XGBoost would fail later with a much less helpful message.
            raise ValueError('%s set has %d rows, which is not a multiple of rank_k=%d'
                             % (set_name, n_rows, rank_k))
        dmatrix = xgb.DMatrix(features, label=labels)
        # Floor division: on Python 3 "/" yields a float, and list * float raises TypeError.
        dmatrix.set_group([rank_k] * (n_rows // rank_k))
        return dmatrix

    train_DMatrix = build_grouped_dmatrix(train_fs, train_labels, 'train')
    valid_DMatrix = build_grouped_dmatrix(valid_fs, valid_labels, 'valid')
    # Both sets are evaluated each round; early stopping monitors the last entry ('valid').
    watchlist = [(train_DMatrix, 'train'), (valid_DMatrix, 'valid')]
    # self.__lock()
    self.model = xgb.train(self.params,
                           train_DMatrix,
                           self.params['num_round'],
                           watchlist,
                           early_stopping_rounds=self.params['early_stop'],
                           verbose_eval=self.params['verbose_eval'])
    # NOTE(review): `best_ntree_limit` / `ntree_limit` are deprecated in newer XGBoost
    # releases (replaced by `best_iteration` / `iteration_range`); confirm against
    # the installed version before upgrading.
    LogUtil.log('INFO', 'best_ntree_limit=%d' % self.model.best_ntree_limit)
    # self.__unlock()
    valid_preds = self.model.predict(valid_DMatrix, ntree_limit=self.model.best_ntree_limit)
    return valid_preds
model.py 文件源码
python
阅读 21
收藏 0
点赞 0
评论 0
评论列表
文章目录