def fit(self, X, y):
    """Fit one clone of ``self.model`` per label column of ``y``.

    For each label (column of ``y``), the items of ``X`` whose entry in that
    column equals 1 are collected. A fresh ``clone(self.model)`` is paired
    with those items; if a label has no such items, ``None`` is used in
    place of a clone. Every (model, items) pair is then passed to
    ``_perform_fit`` through ``Parallel``, and the collected results are
    stored in ``self.models_``.

    Args:
        X: list of samples; ``X[i]`` corresponds to row ``i`` of ``y``.
        y: two-dimensional array-like with one row per sample and one
            column per label. Entries equal to 1 mark the sample as
            training data for that label.

    Raises:
        TypeError: if ``X`` is not a list.
        ValueError: if ``y`` is not two-dimensional, has no rows, or has a
            different number of rows than ``X`` has items.

    Sets ``self.n_labels_`` (number of columns of ``y``) and
    ``self.models_`` (one result of ``_perform_fit`` per label).
    """
    # Validate with real exceptions: ``assert`` statements are removed when
    # Python runs with -O, which would let bad input through silently.
    if not isinstance(X, list):
        raise TypeError("X must be a list, got {}".format(type(X).__name__))
    # Accept any array-like (e.g. a list of lists) for the label matrix.
    y = np.asarray(y)
    if y.ndim != 2:
        raise ValueError(
            "y must be two-dimensional (samples x labels), got {} dimension(s)".format(y.ndim)
        )
    if y.shape[0] == 0:
        raise ValueError("y must contain at least one sample")
    if len(X) != y.shape[0]:
        raise ValueError(
            "X and y must have the same number of samples, got {} and {}".format(len(X), y.shape[0])
        )

    # TODO: add support for fitting again after having already performed a fit
    self.n_labels_ = y.shape[1]
    self.models_ = []

    # Prepare one (model, samples) pair per label. If no data is available
    # for a given label, the model is set to None.
    # NOTE(review): presumably _perform_fit passes a None model through
    # unchanged -- confirm against its definition.
    models, data = [], []
    for idx in range(self.n_labels_):
        samples = [X[i] for i in np.where(y[:, idx] == 1)[0]]
        model = clone(self.model) if samples else None
        data.append(samples)
        models.append(model)

    # Only fan out across workers when the wrapped model reports that it
    # supports parallel execution; otherwise run sequentially.
    n_jobs = self.n_jobs if self.model.supports_parallel() else 1
    self.models_ = Parallel(n_jobs=n_jobs)(
        delayed(_perform_fit)(model, samples) for model, samples in zip(models, data)
    )
    # Internal invariant (not input validation): one result per label.
    assert len(self.models_) == self.n_labels_
评论列表
文章目录