Bilinear.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:pytorch 作者: tylergenter 项目源码 文件源码
def updateOutput(self, input):
        self._assertInput(input)

        # set up buffer:
        if self.buff2 is None:
            self.buff2 = input[0].new()
        self.buff2.resize_as_(input[1])

        # compute output scores:
        self.output.resize_(input[0].size(0), self.weight.size(0))
        for k in range(self.weight.size(0)):
            torch.mm(input[0], self.weight[k], out=self.buff2)
            self.buff2.mul_(input[1])
            torch.sum(self.buff2, 1, out=self.output.narrow(1, k, 1))

        if self.bias is not None:
            self.output.add_(self.bias.view(1, self.bias.nelement()).expand_as(self.output))

        return self.output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号