def test_toeplitz_matmul():
    """Check that the fast Toeplitz matmul agrees with a dense matmul.

    Builds the matrix from ``col`` and ``row`` with
    ``utils.toeplitz.toeplitz`` (assumed to return the dense matrix),
    multiplies it by a right-hand side with ``torch.matmul``, and compares
    the result against ``utils.toeplitz.toeplitz_matmul`` applied to the
    same ``col`` / ``row`` / right-hand side.
    """
    # First column and first row of the Toeplitz matrix; note that
    # col[0] == row[0] so the shared diagonal entry is consistent.
    # Float literals give the default floating dtype, matching what the
    # legacy ``torch.Tensor(...)`` constructor produced.
    col = torch.tensor([1.0, 6.0, 4.0, 5.0])
    row = torch.tensor([1.0, 2.0, 1.0, 1.0])

    # Private, seeded generator: the right-hand side is reproducible across
    # runs (no flaky tolerance failures) without touching global RNG state.
    generator = torch.Generator().manual_seed(0)
    rhs_mat = torch.randn(4, 2, generator=generator)

    # Reference result: materialize the matrix and multiply densely.
    lhs_mat = utils.toeplitz.toeplitz(col, row)
    expected = torch.matmul(lhs_mat, rhs_mat)

    # Fast Toeplitz path under test.
    res = utils.toeplitz.toeplitz_matmul(col, row, rhs_mat)
    assert utils.approx_equal(res, expected)
评论列表
文章目录