def _validate_axis(axis, ndim, argname):
try:
axis = [operator.index(axis)]
except TypeError:
axis = list(axis)
axis = [a + ndim if a < 0 else a for a in axis]
if not builtins.all(0 <= a < ndim for a in axis):
raise ValueError('invalid axis for this array in `%s` argument' %
argname)
if len(set(axis)) != len(axis):
raise ValueError('repeated axis in `%s` argument' % argname)
return axis
评论列表
文章目录