sparse.py source code

python

Project: pytorch-dist  Author: apaszke
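# Excerpt: a method of the Embedding function; `torch` and `type2backend`
# are expected to be available at module scope.
def forward(self, indices, weight):
        # Only index tensors with at most two dimensions are supported.
        assert indices.dim() <= 2
        assert not self.needs_input_grad[0], "Embedding doesn't " \
            "compute the gradient w.r.t. the indices"
        # Select the backend from the weight tensor type and remember the
        # weight shape on the function object.
        self._backend = type2backend[type(weight)]
        self._weight_size = weight.size()

        # A non-contiguous index tensor is copied and kept on `self`;
        # a contiguous one is saved through save_for_backward instead.
        if not indices.is_contiguous():
            self._indices = indices.contiguous()
            indices = self._indices
        else:
            self.save_for_backward(indices)

        # Optional renormalization, delegated to self._renorm (not shown here).
        if self.max_norm is not None:
            self._renorm(indices, weight)

        # Legacy out-argument form: the selected rows are written into `output`.
        output = weight.new()
        if indices.dim() == 1:
            torch.index_select(output, weight, 0, indices)
        else:
            # Flatten the 2-D indices, select rows, then restore the
            # (indices.size(0), indices.size(1), weight.size(1)) shape.
            torch.index_select(output, weight, 0, indices.view(-1))
            output = output.view(indices.size(0), indices.size(1), weight.size(1))

        return output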
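
The lookup at the end of forward() amounts to "pick rows of weight by index, and for 2-D indices flatten first and regroup afterwards". A minimal pure-Python sketch of that shape logic may help; it uses nested lists instead of tensors, and the name embedding_lookup and the sample values are hypothetical, not part of pytorch-dist.

```python
def embedding_lookup(indices, weight):
    """Illustrative stand-in for the row selection in forward().

    `indices` is a list of ints (1-D) or a list of equal-length lists of
    ints (2-D); `weight` is a list of rows. Hypothetical helper, not the
    library's API.
    """
    # 1-D case: select one row of `weight` per index.
    if not indices or not isinstance(indices[0], list):
        return [list(weight[i]) for i in indices]

    # 2-D case: flatten the indices, select rows, then regroup so the
    # result has shape (batch, seq, embedding_dim).
    batch, seq = len(indices), len(indices[0])
    flat = [i for row in indices for i in row]
    selected = [list(weight[i]) for i in flat]
    return [selected[b * seq:(b + 1) * seq] for b in range(batch)]


# Three embeddings of dimension 2 (made-up values).
weight = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]]

out_1d = embedding_lookup([2, 0], weight)
out_2d = embedding_lookup([[0, 1], [2, 2]], weight)
```

Here out_1d has one row per index, and out_2d has two groups of two rows each, mirroring the output.view(...) call in the 2-D branch of the original method.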