sparse.py source code

python

Project: pytorch  Author: ezyang
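import torch
from torch.autograd import Variable  # legacy (pre-0.4) autograd wrapper


def forward(self, input, offsets=None):
    # 2D input: each row is one fixed-length bag, so the offsets are implied.
    if input.dim() == 2:
        if offsets is not None:
            raise ValueError("if input is 2D, then offsets has to be None"
                             ", as input is treated as a mini-batch of"
                             " fixed length sequences. However, found "
                             "offsets of type {}".format(type(offsets)))
        else:
            # Bag i starts at position i * input.size(1) of the flattened input.
            offsets = Variable(torch.arange(0, input.numel(), input.size(1),
                                            out=input.data.new().long()))
            input = input.view(-1)
    elif input.dim() != 1:
        raise ValueError("input has to be 1D or 2D Tensor,"
                         " but got Tensor of dimension {}".format(input.dim()))
    # 1D input: the caller must supply where each bag starts.
    if offsets is None:
        raise ValueError("offsets has to be a 1D Tensor but got None")

    # Delegate the lookup and per-bag reduction (according to self.mode)
    # to the backend implementation.
    return self._backend.EmbeddingBag(
        self.max_norm, self.norm_type,
        self.scale_grad_by_freq, mode=self.mode
    )(self.weight, input, offsets)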
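The input normalization in `forward` and the per-bag reduction that `EmbeddingBag` computes can be sketched in plain Python, which makes the offset arithmetic visible without a PyTorch install. This is a minimal sketch under stated assumptions, not PyTorch's implementation: `to_flat_with_offsets` and `embedding_bag` are hypothetical helper names, the weight matrix is a list of lists, only the `sum` and `mean` modes are modeled, `max_norm`, `scale_grad_by_freq` and gradients are ignored, and an empty bag is assumed to produce a zero vector.

```python
# Hypothetical helpers: a plain-Python sketch of EmbeddingBag semantics,
# not PyTorch's actual implementation.
from typing import List, Optional, Sequence, Tuple, Union

Indices = Union[Sequence[int], Sequence[Sequence[int]]]


def to_flat_with_offsets(
    indices: Indices, offsets: Optional[Sequence[int]] = None
) -> Tuple[List[int], List[int]]:
    """Normalize input the way forward() does: 2D -> flat 1D plus offsets."""
    is_2d = len(indices) > 0 and isinstance(indices[0], (list, tuple))
    if is_2d:
        if offsets is not None:
            raise ValueError("if input is 2D, then offsets has to be None")
        row_len = len(indices[0])
        if row_len == 0 or any(len(row) != row_len for row in indices):
            raise ValueError("2D input must have equal, non-empty rows")
        flat = [idx for row in indices for idx in row]
        # Same values as torch.arange(0, input.numel(), input.size(1)).
        return flat, list(range(0, len(flat), row_len))
    if offsets is None:
        raise ValueError("offsets has to be a 1D sequence but got None")
    return list(indices), list(offsets)


def embedding_bag(
    weight: Sequence[Sequence[float]],
    indices: Indices,
    offsets: Optional[Sequence[int]] = None,
    mode: str = "mean",
) -> List[List[float]]:
    """Look up rows of `weight` and reduce each bag with sum or mean."""
    if mode not in ("sum", "mean"):
        raise ValueError("this sketch only supports mode='sum' or 'mean'")
    flat, offs = to_flat_with_offsets(indices, offsets)
    dim = len(weight[0])
    # Bag b covers flat[offs[b]:offs[b + 1]]; the last bag runs to the end.
    bounds = offs + [len(flat)]
    out = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        acc = [0.0] * dim
        for idx in flat[start:end]:
            for d in range(dim):
                acc[d] += weight[idx][d]
        count = end - start
        if mode == "mean" and count > 0:
            acc = [a / count for a in acc]
        out.append(acc)  # an empty bag stays all zeros (assumption)
    return out


W = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(embedding_bag(W, [[0, 1], [2, 2]]))          # [[2.0, 3.0], [5.0, 6.0]]
print(embedding_bag(W, [0, 2, 1], [0, 1], "sum"))  # [[1.0, 2.0], [8.0, 10.0]]
```

With 2D input every row is one bag, so `[[0, 1], [2, 2]]` becomes the flat list `[0, 1, 2, 2]` with offsets `[0, 2]`, the same pair the `forward` method builds with `torch.arange` and `view(-1)` before handing off to the backend.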