dataset.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:skorch 作者: dnouri 项目源码 文件源码
def __call__(self, X, y):
        """Produce a single train/validation split of ``X`` (and ``y``).

        The cross-validation object comes from ``self.check_cv(y)``. Only the
        first ``(train, valid)`` index pair it yields is used. Split indices
        are computed over ``np.arange(len(X))``. ``y`` is additionally passed
        to ``cv.split`` (converted with ``to_numpy``) only when the cv object
        is stratified.

        Parameters
        ----------
        X
            Data to split. Its length is taken with ``get_len``, and each
            half is obtained via ``multi_indexing`` with the index arrays.
        y
            Targets aligned with ``X``, or ``None``. Must not be ``None``
            when ``self.stratified`` is set.

        Returns
        -------
        tuple
            ``(X_train, X_valid, y_train, y_valid)``. Both ``y`` entries are
            ``None`` when ``y`` is ``None``.

        Raises
        ------
        ValueError
            If stratification is requested but ``y`` is ``None``.
            If stratification is requested but the cv object returned by
            ``check_cv`` is not stratified.
            If ``X`` and ``y`` have different lengths.
        """
        stratify_msg = "Stratified CV not possible with given y."
        if y is None and self.stratified:
            raise ValueError(stratify_msg)

        cv = self.check_cv(y)
        # Stratification was requested, but the resolved cv object cannot
        # honour it.
        if self.stratified and not self._is_stratified(cv):
            raise ValueError(stratify_msg)

        # pylint: disable=invalid-name
        n_samples = get_len(X)
        if y is not None and n_samples != get_len(y):
            raise ValueError("Cannot perform a CV split if X and y "
                             "have different lengths.")

        # The splitter only needs positions; X itself is indexed afterwards.
        positions = np.arange(n_samples)
        if self._is_stratified(cv):
            splits = cv.split(positions, to_numpy(y))
        else:
            splits = cv.split(positions)
        idx_train, idx_valid = next(iter(splits))

        X_train, X_valid = (multi_indexing(X, idx)
                            for idx in (idx_train, idx_valid))
        if y is None:
            y_train = y_valid = None
        else:
            y_train = multi_indexing(y, idx_train)
            y_valid = multi_indexing(y, idx_valid)
        return X_train, X_valid, y_train, y_valid
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号