def left_interp(interp_indices, interp_values, rhs):
    """Multiply ``rhs`` on the left by a sparse interpolation matrix ``W``.

    ``W`` is stored in a compressed per-row form. Row ``i`` of ``W`` has the
    weights ``interp_values[..., i, :]`` in the columns
    ``interp_indices[..., i, :]``. So the result is
    ``sum_j interp_values[..., i, j] * rhs[..., interp_indices[..., i, j], :]``.
    Repeated indices within a row accumulate because of the final sum.

    Args:
        interp_indices: Integer index tensor of shape ``(..., n, k)``. For each
            of the ``n`` output rows it holds the ``k`` row indices of ``rhs``
            to combine. Must be int64 in the matrix case (``torch.gather``
            requirement).
        interp_values: Tensor with the same shape as ``interp_indices``. It holds
            the weight applied to each selected entry.
        rhs: Either a 1-D tensor of shape ``(N,)``, or a matrix / batch of
            matrices of shape ``(..., N, m)``.

    Returns:
        For 1-D ``rhs``: a tensor of shape ``interp_values.shape[:-1]``,
        i.e. ``(..., n)``. Otherwise: a tensor of shape ``(..., n, m)``.
    """
    if rhs.ndimension() == 1:
        # Vector case: look up rhs at every index, restore the (..., n, k)
        # layout, weight it, and reduce over the k interpolation points.
        # ``reshape`` (not ``view``) so non-contiguous index tensors work too.
        selected = rhs.index_select(0, interp_indices.reshape(-1))
        selected = selected.view(*interp_values.size())
        return selected.mul(interp_values).sum(-1)

    # Matrix case. Gathered rows live in a (..., n, k, m) tensor.
    interp_size = list(interp_indices.size()) + [rhs.size(-1)]
    # Same shape but with N (rows of rhs) in place of n. rhs is broadcast to
    # this shape so gather can run along that dimension. The entries are
    # plain ints, so a shallow list copy is sufficient.
    rhs_size = list(interp_size)
    rhs_size[-3] = rhs.size(-2)

    interp_indices_expanded = interp_indices.unsqueeze(-1).expand(*interp_size)
    # rhs: (..., N, m) -> (..., N, 1, m) -> expanded (..., N, k, m). The
    # gather along dim -3 then gives
    # res[..., i, j, c] = rhs[..., interp_indices[..., i, j], c].
    res = rhs.unsqueeze(-2).expand(*rhs_size).gather(-3, interp_indices_expanded)
    res = res.mul(interp_values.unsqueeze(-1).expand(*interp_size))
    # Reduce over the k interpolation points -> (..., n, m).
    return res.sum(-2)
评论列表
文章目录