nce.py 文件源码

python
阅读 45 收藏 0 点赞 0 评论 0

项目:PyTorchText 作者: chenyuntc 项目源码 文件源码
def forward(self, input, indices=None):
        """
        Shape:
            - target_batch :math:`(N, E, 1+N_r)`where `N = length, E = embedding size, N_r = noise ratio`
        """

        if indices is None:
            return super(IndexLinear, self).forward(input)
        # the pytorch's [] operator BP can't correctly
        input = input.unsqueeze(1)
        target_batch = self.weight.index_select(0, indices.view(-1)).view(indices.size(0), indices.size(1), -1).transpose(1,2)
        bias = self.bias.index_select(0, indices.view(-1)).view(indices.size(0), 1, indices.size(1))
        out = torch.baddbmm(1, bias, 1, input, target_batch)
        return out.squeeze()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号