models.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:hh-page-classifier 作者: TeamHG-Memex 项目源码 文件源码
def explain_predictions(self, docs, top=30):
        """Explain the averaged predictions of an XGBoost classifier on ``docs``.

        Each document is routed through the booster to find the leaves it
        lands in (``pred_leaf=True``). Those leaf ids are turned into
        per-feature weights by ``_target_feature_weights``. For each document
        only the weights of features that are non-zero in that document, plus
        the bias term, are kept. The per-document weight vectors are then
        averaged element-wise and handed to ``get_top_features``.

        :param docs: documents accepted by ``self.vec.transform``.
        :param top: forwarded as ``top`` to ``get_top_features``.
        :return: an ``Explanation`` with a single target named ``'y'``.
        :raises NotImplementedError: if ``self.clf`` is not an
            ``XGBClassifier``.
        :raises ValueError: if ``docs`` contains no documents.
        """
        if not isinstance(self.clf, XGBClassifier):
            raise NotImplementedError
        # Newer xgboost exposes the underlying Booster via ``get_booster()``
        # (there ``booster`` is a constructor parameter, not a method);
        # older releases only provide the ``booster()`` method.
        get_booster = getattr(self.clf, 'get_booster', None)
        if callable(get_booster):
            booster = get_booster()
        else:
            booster = self.clf.booster()
        xgb_feature_names = {f: i for i, f in enumerate(booster.feature_names)}
        feature_names = get_feature_names(self.clf, self.vec,
                                          num_features=len(xgb_feature_names))
        feature_names.bias_name = '<BIAS>'
        X = self.vec.transform(docs)
        if X.shape[0] == 0:
            # np.mean over an empty list would yield a NaN scalar and break
            # get_top_features further down; fail early with a clear message.
            raise ValueError('docs must contain at least one document')
        # The DMatrix is fed a CSC matrix. NOTE(review): presumably this is so
        # xgboost keeps the full column count even when trailing columns are
        # empty -- confirm before changing.
        dmatrix = DMatrix(X.tocsc(), missing=self.clf.missing)
        # Per-document row lookups below are cheap on CSR, but each one
        # would scan the whole matrix on CSC.
        X_rows = X.tocsr()
        leaf_ids = booster.predict(dmatrix, pred_leaf=True)
        tree_dumps = booster.get_dump(with_stats=True)
        bias_idx = feature_names.bias_idx
        docs_weights = []
        for i, doc_leaf_ids in enumerate(leaf_ids):
            all_weights = _target_feature_weights(
                doc_leaf_ids, tree_dumps,
                feature_names=feature_names,
                xgb_feature_names=xgb_feature_names)[1]
            weights = np.zeros_like(all_weights)
            # Keep only the features present in this document, plus the bias.
            present_idx = X_rows[i].nonzero()[1]
            weights[present_idx] = all_weights[present_idx]
            weights[bias_idx] = all_weights[bias_idx]
            docs_weights.append(weights)
        weights = np.mean(docs_weights, axis=0)
        feature_weights = get_top_features(
            feature_names=np.array(
                [_prettify_feature(f) for f in feature_names]),
            coef=weights,
            top=top)
        return Explanation(
            estimator=type(self.clf).__name__,
            targets=[TargetExplanation('y', feature_weights=feature_weights)],
        )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号