Bilinear.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:pytorch-dist 作者: apaszke 项目源码 文件源码
def updateOutput(self, input):
        self._assertInput(input)

        # set up buffer:
        self.buff2 = self.buff2 or input[0].new()
        self.buff2.resize_as_(input[1])

        # compute output scores:
        self.output.resize_(input[0].size(0), self.weight.size(0))
        for k in range(self.weight.size(0)):
            torch.mm(self.buff2, input[0], self.weight[k])
            self.buff2.mul_(input[1])
            torch.sum(self.output.narrow(1, k, 1), self.buff2, 1)

        if self.bias:
            self.output.add_(self.bias.view(1, self.bias.nelement()).expand_as(self.output))

        return self.output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号