sparse.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:pytorch 作者: ezyang 项目源码 文件源码
def backward(ctx, grad_output):
    """Compute the weight gradient of a lookup-table (embedding) operation.

    The result is a 7-tuple whose second slot holds the weight gradient and
    whose other slots are ``None`` (presumably one slot per forward input,
    with only the weight being differentiable -- confirm against forward).

    Args:
        ctx: context object prepared by the forward pass. This function reads
            ``_indices`` (falling back to the single entry of
            ``saved_tensors`` when it is ``None``), ``sparse``,
            ``_weight_size``, ``_backend``, ``scale_grad_by_freq`` and
            ``padding_idx``.
        grad_output: gradient w.r.t. the operation's output; it is made
            contiguous before use.

    Returns:
        ``(None, grad_weight, None, None, None, None, None)``.

        * When ``ctx.sparse`` is false, ``grad_weight`` is a zero-initialised
          tensor of size ``ctx._weight_size``, created via
          ``grad_output.new``. It is then filled in place by the backend's
          ``LookupTable_accGradParameters``.
        * Otherwise ``grad_weight`` is a legacy typed sparse tensor. Its
          indices are ``indices.view(1, -1)``, its values are
          ``grad_output.view(-1, ctx._weight_size[1])`` and its size is
          ``ctx._weight_size``.
    """
    idx = ctx._indices
    if idx is None:
        (idx,) = ctx.saved_tensors

    grad_output = grad_output.contiguous()

    if ctx.sparse:
        # Pick the legacy per-type sparse constructor matching grad_output's
        # class name, from the CUDA namespace for CUDA gradients.
        sparse_namespace = torch.cuda.sparse if grad_output.is_cuda else torch.sparse
        make_sparse = getattr(sparse_namespace, type(grad_output).__name__)
        # NOTE(review): unlike the dense path below, this path does not use
        # ctx.padding_idx or ctx.scale_grad_by_freq -- confirm this is intended.
        grad_weight = make_sparse(
            idx.view(1, -1),
            grad_output.view(-1, ctx._weight_size[1]),
            ctx._weight_size,
        )
        return None, grad_weight, None, None, None, None, None

    # The backend call below receives a 1-D index tensor when indices are 2-D.
    flat_idx = idx.view(-1) if idx.dim() == 2 else idx

    # Scratch buffers handed to the backend kernel; only CUDA needs the
    # sorted/index work buffers, CPU passes None for them.
    with torch.cuda.device_of(grad_output):
        if not grad_output.is_cuda:
            count_buf = torch.IntTensor()
            sorted_buf = indices_buf = None
        else:
            sorted_buf = torch.cuda.LongTensor()
            indices_buf = torch.cuda.LongTensor()
            count_buf = torch.cuda.LongTensor()

    grad_weight = grad_output.new(ctx._weight_size).zero_()
    # Original author's note: the backend call doesn't support a Variable
    # grad_output.
    accumulate = ctx._backend.LookupTable_accGradParameters
    accumulate(
        ctx._backend.library_state,
        flat_idx,
        grad_output,
        grad_weight,
        count_buf,
        sorted_buf,
        indices_buf,
        ctx.scale_grad_by_freq,
        ctx.padding_idx,
        1,  # trailing scalar argument -- presumably the gradient scale; confirm
    )
    return None, grad_weight, None, None, None, None, None
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号