test_torch.py 文件源码

python
阅读 37 收藏 0 点赞 0 评论 0

项目:pytorch-dist 作者: apaszke 项目源码 文件源码
def test_baddbmm(self):
        num_batches = 10
        M, N, O = 12, 8, 5
        b1 = torch.randn(num_batches, M, N)
        b2 = torch.randn(num_batches, N, O)
        res = torch.bmm(b1, b2)
        res2 = torch.Tensor().resize_as_(res).zero_()

        res2.baddbmm_(b1,b2)
        self.assertEqual(res2, res)

        res2.baddbmm_(1,b1,b2)
        self.assertEqual(res2, res*2)

        res2.baddbmm_(1,.5,b1,b2)
        self.assertEqual(res2, res*2.5)

        res3 = torch.baddbmm(1,res2,0,b1,b2)
        self.assertEqual(res3, res2)

        res4 = torch.baddbmm(1,res2,.5,b1,b2)
        self.assertEqual(res4, res*3)

        res5 = torch.baddbmm(0,res2,1,b1,b2)
        self.assertEqual(res5, res)

        res6 = torch.baddbmm(.1,res2,.5,b1,b2)
        self.assertEqual(res6, res2 * .1 + res * .5)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号