def train_val_split(self, split_ratio=0.1):
    """Split ``self.X`` / ``self.y`` into train and validation subsets.

    Args:
        split_ratio: Share of the samples placed in the validation subset
            (Default value: 0.1).
            - Single-label data: passed as ``test_size`` to scikit-learn's
              ``StratifiedShuffleSplit``. A float there is read as a proportion
              and must lie strictly between 0 and 1.
            - Multi-label data: forwarded unchanged to
              ``sampling.multi_label_train_test_split``.

    Returns:
        A 4-tuple ``(X_train, X_val, y_train, y_val)`` made by indexing
        ``self.X`` and ``self.y`` with the chosen train/validation indices.

    Note:
        No ``random_state`` is passed to ``StratifiedShuffleSplit``, so the
        single-label split is not seeded and may differ between calls.
    """
    if not self.is_multi_label:
        # A single stratified shuffle split; only its first (and only)
        # (train, test) index pair is needed.
        splitter = StratifiedShuffleSplit(n_splits=1, test_size=split_ratio)
        split_iter = splitter.split(self.X, self.y)
        train_idx, val_idx = next(split_iter)
    else:
        # Multi-label targets are delegated to the project's sampling
        # helper, which returns the (train, val) index arrays directly.
        train_idx, val_idx = sampling.multi_label_train_test_split(self.y, split_ratio)

    X, y = self.X, self.y
    return X[train_idx], X[val_idx], y[train_idx], y[val_idx]
评论列表
文章目录