utils.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:seqmod 作者: emanjavacas 项目源码 文件源码
def select_cols(t, vec):
    """
    `vec[i]` contains the index of the column to pick from the ith row  in `t`

    Parameters
    ----------

    - t: torch.Tensor (m x n)
    - vec: list or torch.LongTensor of size equal to number of rows in t

    >>> x = torch.arange(0, 10).repeat(10, 1).t()  # [[0, ...], [1, ...], ...]
    >>> list(select_cols(x, list(range(10))))
    [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
    """
    if isinstance(vec, list):
        vec = torch.LongTensor(vec)
    return t.gather(1, vec.unsqueeze(1)).squeeze(1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号