def test_toeplitz_matmul_batch():
    """Check batched ``toeplitz_matmul`` against multiplication by dense Toeplitz matrices.

    Each row of ``cols`` / ``rows`` describes one Toeplitz matrix of the batch.
    The reference result materializes every matrix with
    ``utils.toeplitz.toeplitz`` and multiplies with ``torch.matmul``; the
    output of ``utils.toeplitz.toeplitz_matmul`` on the same batch must be
    approximately equal to it (``utils.approx_equal``).
    """
    # Float literals give the tensors torch's default dtype, matching what the
    # legacy ``torch.Tensor(...)`` constructor produced.
    # In every pair the first entries of col and row agree (1/1, 2/2, 1/1) --
    # presumably required because they share the main diagonal; confirm
    # against utils.toeplitz.
    cols = torch.tensor([
        [1., 6., 4., 5.],
        [2., 3., 1., 0.],
        [1., 2., 3., 1.],
    ])
    rows = torch.tensor([
        [1., 2., 1., 1.],
        [2., 0., 0., 1.],
        [1., 5., 1., 0.],
    ])
    batch_size, size = cols.shape
    rhs_mats = torch.randn(batch_size, size, 2)

    # Reference: build each dense Toeplitz matrix, then do a plain batched matmul.
    lhs_mats = torch.stack([
        utils.toeplitz.toeplitz(col, row) for col, row in zip(cols, rows)
    ])
    actual = torch.matmul(lhs_mats, rhs_mats)

    # Fast path under test.
    res = utils.toeplitz.toeplitz_matmul(cols, rows, rhs_mats)
    assert utils.approx_equal(res, actual)
# NOTE(review): web-page residue removed here ("Comment list", "Table of contents") -- not part of the test.