split.py source code

python

Project: cupy  Author: cupy
import numpy

# array_split comes from the same CuPy module (not shown in this excerpt).


def split(ary, indices_or_sections, axis=0):
    """Splits an array into multiple sub arrays along a given axis.

    Args:
        ary (cupy.ndarray): Array to split.
        indices_or_sections (int or sequence of ints): A value indicating how
            to divide the axis. If it is an integer, it is treated as the
            number of sections, and the axis is evenly divided. Otherwise,
            the integers indicate the indices at which to split. Note that a
            sequence stored in device memory is not allowed.
        axis (int): Axis along which the array is split.

    Returns:
        A list of sub arrays, each of which is a view of the input array.

    .. seealso:: :func:`numpy.split`

    """
    if ary.ndim <= axis:
        raise IndexError('Axis exceeds ndim')
    size = ary.shape[axis]

    if numpy.isscalar(indices_or_sections):
        if size % indices_or_sections != 0:
            raise ValueError(
                'indices_or_sections must divide the size along the axis.\n'
                'If you want to split the array into non-equally-sized '
                'arrays, use array_split instead.')
    return array_split(ary, indices_or_sections, axis)
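Because the snippet relies on CuPy's own `array_split` and a GPU, the sketch below mirrors the same control flow on the CPU with NumPy so it can run anywhere. The name `split_like_cupy` is hypothetical, and `numpy.array_split` stands in for CuPy's `array_split` on the assumption that the two behave the same way for this purpose.

```python
import numpy


def split_like_cupy(ary, indices_or_sections, axis=0):
    """Hypothetical CPU sketch of split() above.

    numpy.array_split is used as a stand-in for CuPy's array_split.
    """
    # Reject an axis that is out of range (same check as the original).
    if ary.ndim <= axis:
        raise IndexError('Axis exceeds ndim')
    size = ary.shape[axis]

    # An integer means "number of equal sections", so it must divide
    # the length of the axis exactly.
    if numpy.isscalar(indices_or_sections):
        if size % indices_or_sections != 0:
            raise ValueError(
                'indices_or_sections must divide the size along the axis.')
    # A sequence is passed through unchanged as the split points.
    return numpy.array_split(ary, indices_or_sections, axis)


a = numpy.arange(6)
print([p.tolist() for p in split_like_cupy(a, 3)])       # [[0, 1], [2, 3], [4, 5]]
print([p.tolist() for p in split_like_cupy(a, [1, 4])])  # [[0], [1, 2, 3], [4, 5]]
try:
    split_like_cupy(a, 4)
except ValueError as e:
    print('ValueError:', e)
```

Passing `3` splits the six elements into three equal parts, and `[1, 4]` splits before indices 1 and 4. Passing `4` raises `ValueError` because 6 is not divisible by 4; `numpy.array_split` alone would accept that case and return uneven pieces, which is why `split()` performs the extra divisibility check before delegating.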