regression_cat.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:Steal-ML 作者: ftramer 项目源码 文件源码
def __init__(self, X, y, multinomial, rounding=None):
        """Fit a one-hot encoder and a logistic regression on the full dataset.

        Args:
            X: Feature table exposing ``.columns.values`` and ``.values``
                (presumably a pandas DataFrame -- confirm against callers).
            y: Training targets, handed unchanged to ``LogisticRegression.fit``.
            multinomial: Flag stored on the instance. It must not be truthy
                when ``self.get_classes()`` has exactly two entries.
            rounding: Stored as ``self.rounding``; not otherwise used here.

        Raises:
            AssertionError: If ``multinomial`` is truthy and
                ``len(self.get_classes()) == 2``.
        """
        # Keep the original column labels before dropping to a raw array.
        self.input_features = X.columns.values
        raw = X.values

        # A column is treated as categorical when its smallest value is
        # exactly 0; only those columns are handed to the one-hot encoder.
        zero_min_columns = []
        for col in range(raw.shape[1]):
            if min(raw[:, col]) == 0:
                zero_min_columns.append(col)

        self.encoder = OneHotEncoder(
            categorical_features=zero_min_columns,
            sparse=False,
        )
        encoded = self.encoder.fit_transform(raw)

        # Feature indices refer to the encoded matrix, not the input columns.
        n_encoded = encoded.shape[1]
        self.features = range(n_encoded)
        self.rounding = rounding

        # The model is fitted on every row that was passed in.
        self.model = LogisticRegression()
        self.model.fit(encoded, y)

        self.w, self.intercept = self.model.coef_, self.model.intercept_
        self.multinomial = multinomial

        # The multinomial flag is rejected when exactly two classes exist.
        assert not multinomial or len(self.get_classes()) != 2

        # Base-class initialisation runs last, after all attributes are set.
        RegressionExtractor.__init__(self)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号