io.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:ltls 作者: kjasinska 项目源码 文件源码
def _encode_labels(le, label_arrays, classes_):
    """Configure ``le`` and return every array in ``label_arrays`` transformed by it.

    When ``classes_`` is None the encoder is fitted on all splits together,
    so that train/valid/test share one label space. Otherwise the supplied
    classes are installed with ``le.set_classes``.

    Args:
        le: label encoder exposing ``fit``, ``set_classes`` and ``transform``.
        label_arrays: non-empty sequence of label arrays, one per split.
        classes_: explicit classes to install, or None to fit from the data.

    Returns:
        list of transformed label arrays, in the same order as ``label_arrays``.
    """
    if classes_ is None:
        if len(label_arrays) == 1:
            # Fit on the labels as-is. Wrapping a single split in
            # np.concatenate would needlessly coerce it to an ndarray,
            # which can fail for ragged multilabel tuples.
            le.fit(label_arrays[0])
        else:
            le.fit(np.concatenate(tuple(label_arrays), axis=0))
    else:
        le.set_classes(classes_)
    # Transform EVERY split; the original per-branch code forgot Ytest
    # when classes_ was supplied.
    return [le.transform(labels) for labels in label_arrays]


def load_dataset(path_train, n_features, path_valid=None, path_test=None, multilabel=False, classes_=None):
    """Load svmlight-format train (and optionally valid/test) files and encode their labels.

    All splits are loaded with ``dtype=np.float32`` and the given
    ``n_features``. Labels of every split are encoded with a single
    ``LabelEncoder2``, either fitted on the union of all loaded splits
    (``classes_ is None``) or configured with the explicit ``classes_``.

    Args:
        path_train: path to the training file.
        n_features: number of feature columns to allocate for every split.
        path_valid: optional path to the validation file.
        path_test: optional path to the test file; requires ``path_valid``.
        multilabel: forwarded to both the loader and ``LabelEncoder2``.
        classes_: optional explicit classes for the encoder; when None the
            encoder is fitted on the loaded labels.

    Returns:
        ``(X, Y, None, None, le)`` when only ``path_train`` is given;
        ``(X, Y, Xvalid, Yvalid, le)`` when ``path_valid`` is also given;
        ``(X, Y, Xvalid, Yvalid, Xtest, Ytest, le)`` when both are given.

    Raises:
        ValueError: if ``path_test`` is given without ``path_valid``.
    """
    # NOTE(review): load_svmlight_file(s) are presumably the sklearn.datasets
    # loaders; their import lives outside this snippet.
    le = LabelEncoder2(multilabel=multilabel)

    if path_valid is None and path_test is None:  # TODO zero_based=True?
        X, Y = load_svmlight_file(path_train, dtype=np.float32, n_features=n_features, multilabel=multilabel)
        (Y,) = _encode_labels(le, (Y,), classes_)
        return X, Y, None, None, le

    if path_valid is None:
        # Previously this fell through and passed None as a file path to the loader.
        raise ValueError("path_test requires path_valid to be given as well")

    if path_test is None:
        X, Y, Xvalid, Yvalid = load_svmlight_files((path_train, path_valid), dtype=np.float32,
                                                   n_features=n_features,
                                                   multilabel=multilabel)
        Y, Yvalid = _encode_labels(le, (Y, Yvalid), classes_)
        return X, Y, Xvalid, Yvalid, le

    X, Y, Xvalid, Yvalid, Xtest, Ytest = load_svmlight_files((path_train, path_valid, path_test), dtype=np.float32,
                                                             n_features=n_features,
                                                             multilabel=multilabel)
    Y, Yvalid, Ytest = _encode_labels(le, (Y, Yvalid, Ytest), classes_)
    return X, Y, Xvalid, Yvalid, Xtest, Ytest, le
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号