sparse.py source code

python

Project: pytorch  Author: ezyang
    def forward(cls, ctx, indices, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq,
                sparse=False):

        ctx.padding_idx = padding_idx
        ctx.scale_grad_by_freq = scale_grad_by_freq
        ctx._indices = None
        ctx.sparse = sparse

        assert indices.dim() <= 2
        assert not ctx.needs_input_grad[0], "Embedding doesn't " \
            "compute the gradient w.r.t. the indices"

        ctx._backend = type2backend[type(weight)]
        ctx._weight_size = weight.size()

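        # Keep the indices for the backward pass: a contiguous copy is stashed
        # on ctx, otherwise the original tensor is saved via save_for_backward.
        if not indices.is_contiguous():
            ctx._indices = indices.contiguous()
            indices = ctx._indices
        else:
            ctx.save_for_backward(indices)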

        # Apply the max_norm constraint before the lookup.
        if max_norm is not None:
            cls._renorm(ctx, indices, weight, max_norm, norm_type)

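        # 1-D indices select rows directly; 2-D indices are flattened for the
        # lookup and the result is reshaped to (dim0, dim1, embedding_dim).
        if indices.dim() == 1:
            output = torch.index_select(weight, 0, indices)
        else:
            output = torch.index_select(weight, 0, indices.view(-1))
            output = output.view(indices.size(0), indices.size(1), weight.size(1))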

        return output
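
The lookup at the end of `forward` can be illustrated without PyTorch. The following is a minimal sketch, not the library's implementation: nested Python lists stand in for tensors, and the helper name `embedding_lookup` is hypothetical. It mirrors the two branches above: 1-D indices pick rows directly, while 2-D indices are flattened, looked up, and regrouped into the original two leading dimensions.

```python
def embedding_lookup(weight, indices):
    """Look up rows of `weight` for 1-D or 2-D integer `indices`.

    `weight` is a list of rows (num_embeddings x embedding_dim).
    Nested lists are an assumption of this sketch; the original uses tensors.
    """
    if not indices or not isinstance(indices[0], list):
        # 1-D case: one output row per index.
        return [list(weight[i]) for i in indices]

    # 2-D case: flatten the indices, select rows, then regroup the rows
    # so the result has shape (len(indices), len(indices[0]), embedding_dim).
    flat = [i for row in indices for i in row]
    selected = [list(weight[i]) for i in flat]
    width = len(indices[0])
    return [selected[r * width:(r + 1) * width] for r in range(len(indices))]


# Three embeddings of dimension two.
weight = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]]
one_d = embedding_lookup(weight, [2, 0])
two_d = embedding_lookup(weight, [[0, 1], [2, 2]])
```

Here `one_d` has one row per index, and `two_d` keeps the 2 x 2 layout of its index list with an embedding vector in each position.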