matrix.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:paysage 作者: drckf 项目源码 文件源码
def index_select(mat: T.Tensor, index: T.LongTensor, dim: int = 0) -> T.Tensor:
    """
    Select the specified indices of a tensor along dimension dim.
    For example, dim = 1 is equivalent to mat[:, index] in numpy.

    Args:
        mat (tensor (num_samples, num_units))
        index (tensor; 1 -dimensional)
        dim (int)

    Returns:
        if dim == 0:
            mat[index, :]
        if dim == 1:
            mat[:, index]

    """
    return torch.index_select(mat, dim, index)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号