def split(ary, indices_or_sections, axis=0):
    """Splits an array into multiple sub arrays along a given axis.

    Args:
        ary (cupy.ndarray): Array to split.
        indices_or_sections (int or sequence of ints): A value indicating how
            to divide the axis. If it is an integer, then it is treated as the
            number of sections, and the axis is evenly divided. Otherwise,
            the integers indicate indices to split at. Note that the sequence
            on the device memory is not allowed.
        axis (int): Axis along which the array is split. Negative values
            index from the last axis; the value is passed through unchanged
            to ``array_split``.

    Returns:
        A list of sub arrays. Each array is a view of the corresponding input
        array.

    Raises:
        IndexError: If ``axis`` is outside ``[-ary.ndim, ary.ndim)``.
        ValueError: If ``indices_or_sections`` is a scalar that does not
            evenly divide ``ary.shape[axis]``.

    .. seealso:: :func:`numpy.split`
    """
    ndim = ary.ndim
    # Check both bounds. A one-sided check (``ndim <= axis``) lets
    # too-negative axes through, and they then fail inside
    # ``ary.shape[axis]`` with an unhelpful tuple-index message.
    if not -ndim <= axis < ndim:
        raise IndexError('Axis exceeds ndim')
    size = ary.shape[axis]
    if numpy.isscalar(indices_or_sections):
        # Equal-split mode: only the divisibility check is done here; the
        # slicing itself is delegated to array_split in both modes.
        if size % indices_or_sections != 0:
            raise ValueError(
                'indices_or_sections must divide the size along the axes.\n'
                'If you want to split the array into non-equally-sized '
                'arrays, use array_split instead.')
    return array_split(ary, indices_or_sections, axis)
评论列表
文章目录