def test_matmul_out(self):
    """Check that ``torch.matmul`` fills a non-contiguous ``out`` tensor correctly.

    Two operand shape pairs are exercised: a 3-D tensor times a 3-D tensor,
    and a 3-D tensor times a 2-D matrix.
    """

    def check_matmul(size1, size2):
        """Compare ``torch.matmul(a, b, out=out)`` against the out-of-place result.

        Args:
            size1: Shape of the random left operand.
            size2: Shape of the random right operand.
        """
        a = torch.randn(size1)
        b = torch.randn(size2)
        expected = torch.matmul(a, b)

        # Zero-filled buffer with the same shape, dtype and device as the
        # reference result (replaces the legacy ``torch.Tensor(size).zero_()``).
        out = torch.zeros_like(expected)
        # Make the output non-contiguous: materialise the transposed layout,
        # then transpose back so the shape is unchanged but the strides differ.
        out = out.transpose(-1, -2).contiguous().transpose(-1, -2)
        # Guard the precondition so the test cannot silently degrade into a
        # contiguous-output check.
        self.assertFalse(out.is_contiguous())

        torch.matmul(a, b, out=out)
        self.assertEqual(expected, out)

    check_matmul((2, 3, 4), (2, 4, 5))
    check_matmul((2, 3, 4), (4, 5))
评论列表
文章目录