def test_cat(self):
    """Check that ``torch.cat`` joins tensors correctly along dims 0, 1 and 2.

    For each ``dim``, random tensors that differ in length only along
    ``dim`` are concatenated, first two at a time and then three at a time.
    Every input is then recovered from the result with ``narrow`` and
    compared against the original. Finally, concatenating an empty list is
    expected to raise ``ValueError``.

    NOTE(review): the trailing ``0`` passed to ``assertEqual`` is presumably
    the tolerance argument of the project's TestCase -- confirm against the
    enclosing class.
    """
    SIZE = 10

    # Lengths along the concatenation dim: the 2-arg case first, then the
    # 3-arg ("iterables") case.
    for lengths in ((13, 17), (13, 17, 19)):
        for dim in range(3):
            # transpose(0, dim) moves the differing length onto `dim`, so
            # the inputs agree on every other dimension.
            parts = [
                torch.rand(n, SIZE, SIZE).transpose(0, dim) for n in lengths
            ]
            joined = torch.cat(tuple(parts), dim)

            # Walk the result with a running offset instead of hard-coding
            # 0 / 13 / 30, so the slices always line up with `lengths`.
            offset = 0
            for n, part in zip(lengths, parts):
                self.assertEqual(joined.narrow(dim, offset, n), part, 0)
                offset += n

    # The empty-input check does not depend on `dim`, so it runs once.
    self.assertRaises(ValueError, lambda: torch.cat([]))
评论列表
文章目录