def index_select(mat: T.Tensor, index: T.LongTensor, dim: int = 0) -> T.Tensor:
    """
    Pick out slices of a tensor at the given positions along one dimension.

    With dim = 1 the result matches the numpy expression mat[:, index].

    Args:
        mat (tensor (num_samples, num_units)): tensor to select from
        index (tensor; 1 -dimensional): positions to take along ``dim``
        dim (int): dimension along which the selection is made

    Returns:
        tensor:
            mat[index, :] if dim == 0
            mat[:, index] if dim == 1

    """
    # Delegate to torch; keyword arguments make the (input, dim, index)
    # ordering explicit, since it differs from this wrapper's own signature.
    selected = torch.index_select(input=mat, dim=dim, index=index)
    return selected
评论列表
文章目录