def check_multiplication_dims(dims, N, M, vidx=False, without=False):
    """Validate and sort the dimensions (modes) selected for a multiplication.

    Parameters
    ----------
    dims : int or sequence of int
        Dimension indices to operate on. An empty sequence selects all
        dimensions ``0 .. N-1``.
    N : int
        Number of available dimensions; valid indices are ``0 .. N-1``.
    M : int
        Number of multiplicants supplied by the caller. Only checked when
        ``vidx`` is true.
    vidx : bool, optional
        If true, also validate ``M`` and return, for each sorted dimension,
        the index of the multiplicant to pair with it.
    without : bool, optional
        If true, select every dimension in ``0 .. N-1`` *except* those in
        ``dims``. Note that an empty ``dims`` is expanded to all dimensions
        before the complement is taken, so ``dims=[]`` with ``without=True``
        selects nothing.

    Returns
    -------
    sdims : ndarray
        The selected dimensions in ascending order.
    vidx : ndarray
        Only returned when ``vidx`` is true.
        - If ``M`` equals the number of selected dimensions, ``vidx[i]`` is
          the position of ``sdims[i]`` within the selection before sorting.
        - Otherwise ``M == N`` (one multiplicant per dimension), and ``vidx``
          is ``sdims`` itself.

    Raises
    ------
    ValueError
        If any dimension lies outside ``0 .. N-1``, if ``M > N``, or if
        ``M`` matches neither ``N`` nor the number of selected dimensions.
    """
    dims = array(dims, ndmin=1)
    if len(dims) == 0:
        dims = arange(N)
    if without:
        dims = setdiff1d(range(N), dims)
    # np.isin supersedes np.in1d, which is deprecated since NumPy 2.0.
    if not np.isin(dims, arange(N)).all():
        raise ValueError('Invalid dimensions')
    P = len(dims)
    sidx = np.argsort(dims)
    sdims = dims[sidx]
    if not vidx:
        return sdims
    if M > N:
        raise ValueError('More multiplicants than dimensions')
    if M != N and M != P:
        raise ValueError('Invalid number of multiplicants')
    # One multiplicant per selected dim: map through the sort permutation.
    # Otherwise (M == N, one per dim): index multiplicants by dimension.
    vidx = sidx if P == M else sdims
    return sdims, vidx
评论列表
文章目录