utils.py source file

python
Project: main  Author: rmkemker
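import numpy as np
# train_test_split lives in sklearn.model_selection in current scikit-learn.
from sklearn.model_selection import train_test_split


def train_test_split_per_class(X, y, train_size=None, test_size=None):
    """Split X and y into train and test sets separately for each class.

    train_size and test_size are passed unchanged to train_test_split,
    so they apply to each class, not to the dataset as a whole.
    """
    # Start from empty arrays that keep X's trailing dimensions and the dtypes.
    sh = np.array(X.shape)
    sh[0] = 0
    X_train_arr = np.zeros(sh, dtype=X.dtype)
    X_test_arr = np.zeros(sh, dtype=X.dtype)
    y_train_arr = np.zeros((0), dtype=y.dtype)
    y_test_arr = np.zeros((0), dtype=y.dtype)

    # Iterate over the labels actually present in y. Looping over
    # range(len(np.bincount(y))) fails on a label that has no samples.
    for label in np.unique(y):
        mask = y == label
        X_train, X_test, y_train, y_test = train_test_split(X[mask], y[mask],
                                                            train_size=train_size,
                                                            test_size=test_size)

        # Append this class's split to the accumulated results.
        X_train_arr = np.append(X_train_arr, X_train, axis=0)
        X_test_arr = np.append(X_test_arr, X_test, axis=0)
        y_train_arr = np.append(y_train_arr, y_train)
        y_test_arr = np.append(y_test_arr, y_test)

    return X_train_arr, X_test_arr, y_train_arr, y_test_arr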
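
To make the per-class behavior concrete, here is a minimal NumPy-only sketch of the same idea. It is not the author's code: the helper name `split_per_class_numpy`, the `train_fraction` and `seed` parameters, and the toy data are all assumptions made for illustration, and it replaces the scikit-learn call with a plain shuffle-and-cut on each class.

```python
import numpy as np


def split_per_class_numpy(X, y, train_fraction=0.75, seed=0):
    """Hypothetical helper: shuffle each class and cut it at train_fraction."""
    rng = np.random.default_rng(seed)
    train_idx, test_idx = [], []
    for label in np.unique(y):
        # Indices of the samples belonging to this class, in random order.
        idx = rng.permutation(np.flatnonzero(y == label))
        n_train = int(round(train_fraction * len(idx)))
        train_idx.append(idx[:n_train])
        test_idx.append(idx[n_train:])
    train_idx = np.concatenate(train_idx)
    test_idx = np.concatenate(test_idx)
    return X[train_idx], X[test_idx], y[train_idx], y[test_idx]


# Toy data (assumed): 8 samples of class 0 and 4 samples of class 1.
X = np.arange(24).reshape(12, 2)
y = np.array([0] * 8 + [1] * 4)

X_train, X_test, y_train, y_test = split_per_class_numpy(X, y)
train_counts = np.bincount(y_train)  # samples per class in the train set
test_counts = np.bincount(y_test)    # samples per class in the test set
```

Because the cut is made inside each class, both classes keep the same train/test ratio even though class 0 is twice as large as class 1. The results are also grouped by class, as in the function above, so shuffle them before training if sample order matters.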