model.py source code

python

Project: zhihu-machine-learning-challenge-2017 Author: HouJP
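def fit(self,
            train_fs, train_labels,
            valid_fs, valid_labels):
        """Train an XGBoost model with fixed-size query groups and return
        predictions on the validation set."""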
        rank_k = self.config.getint('RANK', 'rank_k')

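        # Every query owns exactly rank_k consecutive rows, so the group list is
        # rank_k repeated once per query. Use // so the repeat count is an int
        # (with /, Python 3 yields a float and list repetition raises TypeError).
        train_DMatrix = xgb.DMatrix(train_fs, label=train_labels)
        train_DMatrix.set_group([rank_k] * (len(train_labels) // rank_k))
        valid_DMatrix = xgb.DMatrix(valid_fs, label=valid_labels)
        valid_DMatrix.set_group([rank_k] * (len(valid_labels) // rank_k))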

        watchlist = [(train_DMatrix, 'train'), (valid_DMatrix, 'valid')]
        # self.__lock()
        self.model = xgb.train(self.params,
                               train_DMatrix,
                               self.params['num_round'],
                               watchlist,
                               early_stopping_rounds=self.params['early_stop'],
                               verbose_eval=self.params['verbose_eval'])
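        # NOTE: best_ntree_limit and the ntree_limit argument belong to older
        # XGBoost releases; XGBoost 2.0 and later expose best_iteration and
        # the iteration_range argument of predict instead.
        LogUtil.log('INFO', 'best_ntree_limit=%d' % self.model.best_ntree_limit)
        # self.__unlock()
        valid_preds = self.model.predict(valid_DMatrix, ntree_limit=self.model.best_ntree_limit)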
        return valid_preds
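
The group list passed to `set_group` in `fit` has one entry per query, each equal to `rank_k`. The sketch below isolates that computation in a small helper so the row count can be validated before training. `make_group_sizes` is a hypothetical name and is not part of the original project. It assumes the rows are already ordered so that each consecutive block of `rank_k` rows belongs to one query.

```python
def make_group_sizes(n_rows, rank_k):
    """Return a list of group sizes in the form DMatrix.set_group expects.

    Hypothetical helper (not from the original project). Assumes every
    query contributes exactly rank_k consecutive rows.
    """
    if rank_k <= 0:
        raise ValueError("rank_k must be positive")
    if n_rows % rank_k != 0:
        # A remainder would leave some rows outside any group.
        raise ValueError(
            "n_rows=%d is not a multiple of rank_k=%d" % (n_rows, rank_k))
    # Integer division keeps the repeat count an int on Python 3.
    return [rank_k] * (n_rows // rank_k)


group_sizes = make_group_sizes(12, 4)
print(group_sizes)  # three groups of four rows each
```

The result could be passed as `train_DMatrix.set_group(make_group_sizes(len(train_labels), rank_k))`. Failing early on a row count that is not a multiple of `rank_k` is a design choice of this sketch. The original code would silently leave the trailing rows ungrouped.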