def local_subtensor_of_dot(node):
    """
    Rewrite ``T.dot(A, B)[idxs]`` as ``T.dot(A[idxs_a], B[idxs_b])``.

    The indices are split between the two operands of the dot:

    * ``idxs_a`` is the first ``A.ndim - 1`` entries of ``idxs`` (fewer when
      ``idxs`` is shorter than that);
    * ``idxs_b`` is whatever remains, with a full slice inserted at the
      second-to-last dimension of ``B`` so that the dimension summed over
      by the dot is left untouched.

    Parameters
    ----------
    node
        The apply node being examined. Only ``node.op``, ``node.inputs``
        and ``node.outputs`` are read.

    Returns
    -------
    list or None
        A one-element list holding the replacement dot product, or ``None``
        when the rewrite is not applied: ``node.op`` is not a ``Subtensor``,
        the indexed variable was not produced by a ``T.Dot``, or that
        variable has more than one client.
    """
    if not isinstance(node.op, Subtensor):
        return

    dot_out = node.inputs[0]
    dot_apply = dot_out.owner
    if not dot_apply or not isinstance(dot_apply.op, T.Dot):
        return

    # When something else also consumes the full dot product, indexing the
    # operands would mean computing the sub-part a second time.
    if len(dot_out.clients) > 1:
        return

    a = dot_apply.inputs[0]
    b = dot_apply.inputs[1]

    # NOTE(review): get_idx_list is defined elsewhere; its result is sliced
    # and concatenated with a tuple below, so it is expected to be a tuple.
    all_indices = get_idx_list(node.inputs, node.op.idx_list)

    # The leading indices go to `a`; its last dimension is the one the dot
    # sums over, so at most a.ndim - 1 of them can be consumed here.
    split_at = min(a.ndim - 1, len(all_indices))
    a_indices = all_indices[:split_at]
    b_indices = all_indices[split_at:]

    # The dot sums the last dimension of `a` with the second-to-last
    # dimension of `b`, so that dimension of `b` must not be indexed: a full
    # slice is inserted in its place. For `a` simply omitting the last index
    # was enough. Nothing is inserted when b.ndim <= 1, because then the
    # wanted result is `b` itself rather than `b[:]`.
    if b.ndim > 1 and len(b_indices) >= b.ndim - 1:
        summed_axis = b.ndim - 2
        b_indices = (b_indices[:summed_axis] +
                     (slice(None),) +
                     b_indices[summed_axis:])

    a_sub = a[tuple(a_indices)]
    if b_indices:
        b_sub = b[tuple(b_indices)]
    else:
        b_sub = b

    # An error raised by indexing either operand (e.g. an index error)
    # corresponds to an error in the original subtensor of the dot product,
    # so the original output's stack trace is copied onto both.
    copy_stack_trace(node.outputs[0], [a_sub, b_sub])

    # An error in the new dot may stem from either the original dot product
    # or the original subtensor, so both stack traces are copied onto it.
    new_dot = T.dot(a_sub, b_sub)
    copy_stack_trace([node.outputs[0], dot_out], new_dot)

    return [new_dot]
# (web-page residue) Comment list / Article table of contents