def get_idx_list(inputs, idx_list, get_count=False):
    """
    Rebuild the concrete index tuple of a subtensor from its inputs.

    Every entry of ``idx_list`` that is a ``gof.Type`` instance, either
    directly or as the start/stop/step of a ``slice``, acts as a
    placeholder.  Placeholders are replaced, in order of appearance, by
    the successive elements of ``inputs[1:]``.  All other entries are
    kept unchanged.

    Parameters
    ----------
    inputs
        Sequence of inputs to the subtensor.  The first element is
        skipped (presumably the object being indexed -- confirm against
        the caller); the remaining elements are the index values.
    idx_list
        Iterable of index entries, possibly containing placeholders.
    get_count
        If True, return the number of inputs consumed while filling
        the placeholders instead of the converted indices.

    Returns
    -------
    tuple or int
        The converted indices as a tuple, or the number of consumed
        inputs when ``get_count`` is True.

    Raises
    ------
    IndexError
        If ``idx_list`` holds more placeholders than there are index
        inputs available (only possible when at least one index input
        was supplied).
    """
    # The number of index inputs (everything after the first input).
    n = len(inputs) - 1

    # The subtensor (or idx_list) does not depend on the inputs, so no
    # substitution is performed and idx_list is used as is.
    if n == 0:
        # Honour get_count here too: nothing was consumed.  Previously
        # this path returned the index tuple even when a count was
        # requested.
        if get_count:
            return 0
        return tuple(idx_list)

    # Reversed so that pop() hands the inputs out in their original order.
    indices = list(inputs[1:])
    indices.reverse()

    # General case
    def convert(entry):
        # A placeholder consumes the next pending input.
        if isinstance(entry, gof.Type):
            return indices.pop()
        # Slices may hide placeholders in any of their three fields.
        if isinstance(entry, slice):
            return slice(convert(entry.start),
                         convert(entry.stop),
                         convert(entry.step))
        return entry

    cdata = tuple(convert(entry) for entry in idx_list)
    if get_count:
        # Whatever is no longer on the stack has been consumed.
        return n - len(indices)
    return cdata
# (Web-page residue, not part of the code: "Comment list" / "Table of contents".)