def select_cols(t, vec):
    """
    Pick one element from each row of a 2-D tensor.

    ``vec[i]`` is the index of the column to pick from the i-th row of ``t``,
    so the result is ``[t[0, vec[0]], t[1, vec[1]], ...]``.

    Parameters
    ----------
    t : torch.Tensor
        Tensor of shape (m, n).
    vec : sequence of int or torch.Tensor
        Column indices, one per row of ``t`` (length m). Lists, tuples,
        NumPy arrays and tensors of any integer dtype/device are accepted;
        they are converted to an int64 tensor on ``t``'s device, which is
        what ``Tensor.gather`` requires.

    Returns
    -------
    torch.Tensor
        1-D tensor of length ``len(vec)`` with the same dtype as ``t``.

    Notes
    -----
    ``gather`` does not check that ``len(vec) == m``: a shorter ``vec``
    silently yields a shorter result (only the first ``len(vec)`` rows are
    used), while a longer one raises a RuntimeError from ``gather``.

    Examples
    --------
    >>> x = torch.arange(0, 10).repeat(10, 1).t()  # [[0, ...], [1, ...], ...]
    >>> select_cols(x, list(range(10))).tolist()
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    """
    # as_tensor avoids a copy when vec is already an int64 tensor on t's
    # device; otherwise it converts/moves it so gather sees a valid index.
    index = torch.as_tensor(vec, dtype=torch.long, device=t.device)
    # gather needs an index with the same rank as t: (m,) -> (m, 1) -> (m,).
    return t.gather(1, index.unsqueeze(1)).squeeze(1)
评论列表
文章目录