test_torch.py source code

python

Project: pytorch-dist  Author: apaszke
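def test_cat(self):
    SIZE = 10
    # 2-arg cat: x and y differ in length only along `dim`
    for dim in range(3):
        # transpose(0, dim) moves the varying-length axis to position `dim`
        x = torch.rand(13, SIZE, SIZE).transpose(0, dim)
        y = torch.rand(17, SIZE, SIZE).transpose(0, dim)
        res1 = torch.cat((x, y), dim)
        # narrow(dim, start, length) slices each input back out of the result;
        # the trailing 0 appears to be the tolerance of the project's assertEqual
        self.assertEqual(res1.narrow(dim, 0, 13), x, 0)
        self.assertEqual(res1.narrow(dim, 13, 17), y, 0)

    # Check iterables
    for dim in range(3):
        x = torch.rand(13, SIZE, SIZE).transpose(0, dim)
        y = torch.rand(17, SIZE, SIZE).transpose(0, dim)
        z = torch.rand(19, SIZE, SIZE).transpose(0, dim)

        res1 = torch.cat((x, y, z), dim)
        self.assertEqual(res1.narrow(dim, 0, 13), x, 0)
        self.assertEqual(res1.narrow(dim, 13, 17), y, 0)
        # z starts at offset 13 + 17 = 30
        self.assertEqual(res1.narrow(dim, 30, 19), z, 0)
        # Concatenating an empty sequence is expected to raise ValueError here
        self.assertRaises(ValueError, lambda: torch.cat([]))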
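The method above depends on the project's own TestCase, whose assertEqual takes an extra third argument. The following is a minimal standalone sketch of the same round-trip check that can be run outside the test suite. The helper name `check_cat` is hypothetical, and `torch.equal` stands in for the project's assertEqual as an exact comparison. The empty-list case is left out because the exception type raised by `torch.cat([])` may differ between PyTorch versions.

```python
import torch

SIZE = 10


def check_cat(dim):
    """Concatenate three tensors along `dim` and check each one can be sliced back out.

    Hypothetical helper for illustration; not part of test_torch.py.
    """
    lengths = (13, 17, 19)
    # transpose(0, dim) moves the varying-length axis to position `dim`
    parts = [torch.rand(n, SIZE, SIZE).transpose(0, dim) for n in lengths]
    result = torch.cat(parts, dim)

    offset = 0
    for n, part in zip(lengths, parts):
        # narrow(dim, start, length) returns `length` entries starting at `start`
        assert torch.equal(result.narrow(dim, offset, n), part)
        offset += n
    # Size of the result along the concatenation axis
    return result.size(dim)


totals = [check_cat(dim) for dim in range(3)]
```

Each entry of `totals` is the size of the concatenated tensor along `dim`, which should be the sum of the input lengths (13 + 17 + 19).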