def train_test_split(X, y, test_size=0.25, random_state=42, stratify=True):
    """Split ``X`` and ``y`` into a single train/test partition.

    Parameters
    ----------
    X : array-like
        Feature data. It is indexed with the integer index arrays produced by
        the splitter, so it presumably must be a NumPy array or similar
        (TODO confirm callers never pass plain lists).
    y : array-like
        Targets, indexed the same way as ``X``. When ``stratify`` is true it
        is also passed to ``StratifiedKFold``.
    test_size : float
        Requested fraction of samples in the test set. With ``stratify=True``
        this is only approximated: the split is the first of
        ``int(round(1 / test_size))`` stratified folds. A ``test_size`` of 0
        raises ZeroDivisionError.
    random_state : int
        Seed forwarded to the splitter.
    stratify : bool
        If true, split with ``StratifiedKFold`` on ``y``. Otherwise split with
        ``ShuffleSplit`` over ``len(y)`` samples.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test)``.
    """
    if stratify:
        # Presumably each of K stratified folds holds out ~1/K of the data,
        # so K = round(1 / test_size) approximates the requested test fraction.
        n_folds = int(round(1 / test_size))
        # NOTE(review): in the legacy StratifiedKFold(y, n_folds=...) API,
        # random_state is documented to apply only when shuffle=True. Here no
        # shuffle is requested, so this split may ignore random_state and be
        # deterministic. Confirm whether shuffling was intended.
        splitter = StratifiedKFold(y, n_folds=n_folds, random_state=random_state)
    else:
        splitter = ShuffleSplit(len(y), test_size=test_size, random_state=random_state)
    # Take only the first (train, test) index pair. The ``next()`` builtin
    # replaces the Python-2-only ``iterator.next()`` method.
    train_idx, test_idx = next(iter(splitter))
    return X[train_idx], X[test_idx], y[train_idx], y[test_idx]
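# 评论列表 ("Comment list") — leftover web-page text from the scraped source, not code
# 文章目录 ("Table of contents") — leftover web-page text from the scraped source, not code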