dataset.py source code

python

Project: heamy  Author: rushter
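from sklearn.model_selection import KFold, StratifiedKFold

# Note: `idx` is a row-selection helper defined elsewhere in the project.


def kfold(self, k=5, stratify=False, shuffle=True, seed=33):
    """K-Folds cross validation iterator.

    Parameters
    ----------
    k : int, default 5
        Number of folds.
    stratify : bool, default False
        Use StratifiedKFold so each fold keeps the class proportions of y_train.
    shuffle : bool, default True
        Shuffle the samples before splitting.
    seed : int, default 33
        Random seed; only used when shuffle is True.

    Yields
    ------
    X_train, y_train, X_test, y_test, train_index, test_index
    """
    # Recent scikit-learn versions raise ValueError when random_state is set
    # while shuffle=False, so pass the seed only when shuffling.
    random_state = seed if shuffle else None
    if stratify:
        kf = StratifiedKFold(n_splits=k, random_state=random_state, shuffle=shuffle)
    else:
        kf = KFold(n_splits=k, random_state=random_state, shuffle=shuffle)

    # Both the training and the held-out part of each fold come from the training set.
    for train_index, test_index in kf.split(self.X_train, self.y_train):
        X_train, y_train = idx(self.X_train, train_index), self.y_train[train_index]
        X_test, y_test = idx(self.X_train, test_index), self.y_train[test_index]
        yield X_train, y_train, X_test, y_test, train_index, test_index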
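
The generator above can be exercised on its own with a small stand-in container. The following is a minimal sketch, not the project's actual class: `MiniDataset` and the `idx` helper shown here are hypothetical, written only for illustration, and the toy arrays are made up. It also assumes a recent scikit-learn, which rejects a `random_state` combined with `shuffle=False`, so the seed is passed only when shuffling.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold


def idx(data, index):
    # Hypothetical stand-in for the project's helper: positional row selection
    # for pandas objects (via .iloc) or NumPy arrays (plain indexing).
    return data.iloc[index] if hasattr(data, "iloc") else data[index]


class MiniDataset:
    """Hypothetical minimal container exposing X_train and y_train."""

    def __init__(self, X_train, y_train):
        self.X_train = X_train
        self.y_train = y_train

    def kfold(self, k=5, stratify=False, shuffle=True, seed=33):
        splitter = StratifiedKFold if stratify else KFold
        # Pass the seed only when shuffling (assumption: recent scikit-learn).
        kf = splitter(n_splits=k, shuffle=shuffle,
                      random_state=seed if shuffle else None)
        for train_index, test_index in kf.split(self.X_train, self.y_train):
            yield (idx(self.X_train, train_index), self.y_train[train_index],
                   idx(self.X_train, test_index), self.y_train[test_index],
                   train_index, test_index)


# Toy data: 10 samples, 2 features, two balanced classes.
X = np.arange(20).reshape(10, 2)
y = np.array([0, 1] * 5)
dataset = MiniDataset(X, y)

fold_sizes = []  # (train size, held-out size) per fold
held_out = []    # every index that appeared in a held-out fold
for X_tr, y_tr, X_te, y_te, tr_idx, te_idx in dataset.kfold(k=5, stratify=True):
    fold_sizes.append((len(X_tr), len(X_te)))
    held_out.extend(te_idx.tolist())

# Non-shuffled splitting also works because no seed is passed in that case.
unshuffled_folds = list(dataset.kfold(k=2, shuffle=False))
```

Each sample lands in exactly one held-out fold, so collecting `test_index` across all folds gives back every row of the training set once. That property is what makes the yielded indices useful for assembling out-of-fold predictions.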