base.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:motion-classification 作者: matthiasplappert 项目源码 文件源码
def fit(self, X, y):
    """Fit one clone of ``self.model`` per label column of ``y``.

    For each label index ``idx`` the training data is the subset of ``X``
    whose rows satisfy ``y[i, idx] == 1``. A label with no such rows gets
    ``None`` instead of a model. The per-label fits are dispatched through
    ``Parallel``/``delayed``, using ``self.n_jobs`` workers only when
    ``self.model.supports_parallel()`` is true, otherwise a single job.

    Parameters
    ----------
    X : list
        Training samples, one entry per row of ``y``. Entries are passed to
        ``_perform_fit`` unchanged; their own type is not checked here.
    y : array-like supporting ``y.shape`` and ``y[:, idx]`` indexing
        Label indicator matrix; presumably shape ``(n_samples, n_labels)``
        with 1 marking label membership -- TODO confirm against callers.

    Returns
    -------
    self
        Returned to allow call chaining. ``self.n_labels_`` and
        ``self.models_`` are set as side effects.

    Raises
    ------
    TypeError
        If ``X`` is not a list.
    ValueError
        If ``y`` is empty or ``X`` and ``y`` differ in length.
    """
    # Explicit checks instead of asserts: asserts are removed under ``python -O``.
    if not isinstance(X, list):
        raise TypeError('X must be a list, got {}'.format(type(X).__name__))
    if len(y) == 0:
        raise ValueError('y must not be empty')
    if len(X) != len(y):
        raise ValueError(
            'X and y must have the same length ({} != {})'.format(len(X), len(y)))

    # TODO: add support for fitting again after having already performed a fit
    self.n_labels_ = y.shape[1]
    # Reset first so a failed fit below does not leave stale models behind.
    self.models_ = []

    # Partition X by label. A label without any positive sample gets no model (None).
    data = [[X[i] for i in np.where(y[:, idx] == 1)[0]]
            for idx in range(self.n_labels_)]
    models = [clone(self.model) if samples else None for samples in data]

    # Only fan out when the underlying model declares it is safe to do so.
    n_jobs = self.n_jobs if self.model.supports_parallel() else 1
    self.models_ = Parallel(n_jobs=n_jobs)(
        delayed(_perform_fit)(model, samples)
        for model, samples in zip(models, data))
    assert len(self.models_) == self.n_labels_  # internal invariant, not input validation
    return self
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号