def load_dataset(path_train, n_features, path_valid=None, path_test=None, multilabel=False, classes_=None):
    """Load svmlight-format data and encode the labels of every loaded split.

    Parameters
    ----------
    path_train : path of the training file (always loaded).
    n_features : forwarded as ``n_features`` to the svmlight loader.
    path_valid : optional path of the validation file.
    path_test : optional path of the test file.
    multilabel : forwarded to both the loader and ``LabelEncoder2``.
    classes_ : optional predefined classes. When given, the encoder is
        configured with ``set_classes(classes_)`` instead of being fitted
        on the loaded labels.

    Returns
    -------
    The shape of the result depends on which paths are supplied:

    * neither ``path_valid`` nor ``path_test``:
      ``(X, Y, None, None, le)``
    * ``path_valid`` only:
      ``(X, Y, Xvalid, Yvalid, le)``
    * ``path_test`` given:
      ``(X, Y, Xvalid, Yvalid, Xtest, Ytest, le)``

    All returned label sets have been passed through ``le.transform``.

    NOTE(review): when ``path_test`` is given but ``path_valid`` is None,
    ``None`` is forwarded to ``load_svmlight_files`` as the second path —
    confirm callers never do this.
    """
    le = LabelEncoder2(multilabel=multilabel)
    load_kwargs = dict(dtype=np.float32, n_features=n_features, multilabel=multilabel)

    if path_valid is None and path_test is None:  # TODO zero_based=True?
        X, Y = load_svmlight_file(path_train, **load_kwargs)
        if classes_ is None:
            le.fit(Y)
        else:
            le.set_classes(classes_)
        return X, le.transform(Y), None, None, le

    if path_test is None:
        X, Y, Xvalid, Yvalid = load_svmlight_files((path_train, path_valid), **load_kwargs)
        if classes_ is None:
            # Fit on the union of splits so labels seen only in validation
            # are still known to the encoder.
            # NOTE(review): with multilabel=True the label containers may be
            # ragged; confirm np.concatenate handles them as intended.
            le.fit(np.concatenate((Y, Yvalid), axis=0))
        else:
            le.set_classes(classes_)
        return X, le.transform(Y), Xvalid, le.transform(Yvalid), le

    X, Y, Xvalid, Yvalid, Xtest, Ytest = load_svmlight_files(
        (path_train, path_valid, path_test), **load_kwargs
    )
    if classes_ is None:
        le.fit(np.concatenate((Y, Yvalid, Ytest), axis=0))
    else:
        le.set_classes(classes_)
    Y = le.transform(Y)
    Yvalid = le.transform(Yvalid)
    # The test labels are encoded on both paths; previously they were left
    # untransformed whenever classes_ was supplied.
    Ytest = le.transform(Ytest)
    return X, Y, Xvalid, Yvalid, Xtest, Ytest, le
# Comment list
# Table of contents