MM.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:pytorch 作者: tylergenter 项目源码 文件源码
def updateOutput(self, input):
        assert len(input) == 2
        a, b = input
        assert a.ndimension() == 2 or a.ndimension() == 3
        assert a.dim() == b.dim()

        if a.ndimension() == 2:
            if self.transA:
                a = a.t()
            if self.transB:
                b = b.t()
            self.output.resize_(a.size(0), b.size(1))
            torch.mm(a, b, out=self.output)
        else:
            if self.transA:
                a = a.transpose(2, 3)
            if self.transB:
                b = b.transpose(2, 3)

            self.output.resize_(a.size(0), a.size(1), b.size(2))
            torch.bmm(a, b, out=self.output)

        return self.output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号