sparse.py source file

python

Project: pytorch-dist    Author: apaszke
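def backward(self, grad_output):
        # Row indices used in the forward lookup: either cached on the
        # function object or saved for backward via saved_tensors.
        if self._indices is not None:
            indices = self._indices
        else:
            indices, = self.saved_tensors

        # A 2-D (batch x seq) index tensor is flattened so that each index
        # lines up with one row of grad_output.
        if indices.dim() == 2:
            indices = indices.view(-1)

        # The backend kernel expects contiguous memory.
        grad_output = grad_output.contiguous()

        # Scratch buffers for the kernel. The CUDA path needs buffers for the
        # sorted indices, their original positions and per-index counts; the
        # CPU path only needs the count buffer (used by scale_grad_by_freq).
        # Note: only a CUDA FloatTensor takes the GPU branch here.
        if torch.typename(grad_output) == 'torch.cuda.FloatTensor':
            _sorted = torch.cuda.LongTensor()
            _indices = torch.cuda.LongTensor()
            _count = torch.cuda.LongTensor()
        else:
            _count = torch.IntTensor()
            _sorted = _indices = None

        # TODO: sparse updates...
        # For now the gradient is a dense, zero-initialised tensor with the
        # full shape of the embedding weight.
        grad_weight = type(grad_output)(self._weight_size).zero_()
        # Accumulate each grad_output row into the grad_weight row selected by
        # the matching index, honouring scale_grad_by_freq and padding_idx.
        # The final argument (1) is the overall scale factor.
        self._backend.LookupTable_accGradParameters(
            self._backend.library_state,
            indices,
            grad_output,
            grad_weight,
            _count,
            _sorted,
            _indices,
            self.scale_grad_by_freq,
            self.padding_idx,
            1
        )
        # The integer indices get no gradient; only the weight does.
        return None, grad_weight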
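The heavy lifting above happens inside the backend kernel `LookupTable_accGradParameters`. The pure-Python sketch below illustrates the accumulation it is asked to perform, assuming the usual embedding semantics: each gradient row is added into the weight row named by its index, rows are divided by how often their index occurs when `scale_grad_by_freq` is set, and the padding row receives nothing. The function name `lookup_table_grad` and its parameters are hypothetical, and the caller is assumed to have already flattened the indices (as `indices.view(-1)` does) so there is one `grad_output` row per index.

```python
from collections import Counter
from typing import List, Optional, Sequence


def lookup_table_grad(
    indices: Sequence[int],
    grad_output: Sequence[Sequence[float]],
    num_embeddings: int,
    scale_grad_by_freq: bool = False,
    padding_idx: Optional[int] = None,
    scale: float = 1.0,
) -> List[List[float]]:
    """Hypothetical reference: dense gradient of an embedding lookup w.r.t. its weight.

    indices:     flattened row indices, one per looked-up vector.
    grad_output: one gradient row (length embedding_dim) per index.
    """
    if len(indices) != len(grad_output):
        raise ValueError("need exactly one grad_output row per index")
    dim = len(grad_output[0]) if grad_output else 0

    # Dense, zero-initialised gradient with the full weight shape.
    grad_weight = [[0.0] * dim for _ in range(num_embeddings)]

    # How many times each row was looked up in this batch (the role of `_count`).
    counts = Counter(indices)

    for idx, g in zip(indices, grad_output):
        if padding_idx is not None and idx == padding_idx:
            continue  # the padding row never receives a gradient
        factor = scale / counts[idx] if scale_grad_by_freq else scale
        row = grad_weight[idx]
        for j, v in enumerate(g):
            row[j] += factor * v  # repeated indices accumulate
    return grad_weight


if __name__ == "__main__":
    grads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
    print(lookup_table_grad([1, 3, 1], grads, num_embeddings=4))
    print(lookup_table_grad([1, 3, 1], grads, num_embeddings=4, scale_grad_by_freq=True))
```

With indices `[1, 3, 1]`, row 1 collects two gradient rows: summed to `[6.0, 8.0]` by default, or averaged to `[3.0, 4.0]` when `scale_grad_by_freq` is on. The dense result mirrors the `TODO: sparse updates` comment; later PyTorch releases expose the same options on `torch.nn.Embedding` (`padding_idx`, `scale_grad_by_freq`) and add `sparse=True` for row-sparse gradients.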