def train_test_split_per_class(X, y, train_size=None, test_size=None):
    """Split ``X`` and ``y`` into train/test sets one class at a time.

    Every distinct label in ``y`` is split independently with
    ``train_test_split`` and the per-class pieces are concatenated, so
    ``train_size`` / ``test_size`` apply to each class rather than to the
    dataset as a whole.

    Parameters
    ----------
    X : numpy.ndarray
        Samples, indexed along the first axis.
    y : numpy.ndarray
        One-dimensional array of class labels, one per row of ``X``.
    train_size, test_size
        Forwarded unchanged to ``train_test_split`` for every class.
        NOTE(review): ``train_test_split`` is not defined in this block;
        presumably scikit-learn's -- confirm against the file's imports.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test)``. Rows are grouped by class in
        ascending label order; no shuffling happens across classes. Dtypes
        follow ``X.dtype`` and ``y.dtype``. Empty input yields empty arrays
        with the trailing dimensions of ``X``.
    """
    # Zero-length seeds keep the result shape and dtype well defined even
    # when there is no class to split.
    empty_X = np.zeros((0,) + tuple(X.shape[1:]), dtype=X.dtype)
    empty_y = np.zeros((0,), dtype=y.dtype)
    X_train_parts, X_test_parts = [empty_X], [empty_X]
    y_train_parts, y_test_parts = [empty_y], [empty_y]

    # Iterate over the labels that actually occur. Counting classes with
    # np.bincount would also visit every missing integer between 0 and
    # max(y), handing an empty subset to train_test_split, and it rejects
    # negative or non-integer labels altogether.
    for label in np.unique(y):
        mask = y == label
        X_train, X_test, y_train, y_test = train_test_split(
            X[mask], y[mask], train_size=train_size, test_size=test_size
        )
        X_train_parts.append(X_train)
        X_test_parts.append(X_test)
        y_train_parts.append(y_train)
        y_test_parts.append(y_test)

    # Concatenate once at the end instead of re-copying the accumulated
    # arrays on every iteration.
    return (
        np.concatenate(X_train_parts, axis=0),
        np.concatenate(X_test_parts, axis=0),
        np.concatenate(y_train_parts),
        np.concatenate(y_test_parts),
    )
评论列表
文章目录