toeplitz_test.py 文件源码

python
阅读 46 收藏 0 点赞 0 评论 0

项目:gpytorch 作者: jrg365 项目源码 文件源码
def test_toeplitz_matmul_batch():
    cols = torch.Tensor([
        [1, 6, 4, 5],
        [2, 3, 1, 0],
        [1, 2, 3, 1],
    ])
    rows = torch.Tensor([
        [1, 2, 1, 1],
        [2, 0, 0, 1],
        [1, 5, 1, 0],
    ])

    rhs_mats = torch.randn(3, 4, 2)

    # Actual
    lhs_mats = torch.zeros(3, 4, 4)
    for i, (col, row) in enumerate(zip(cols, rows)):
        lhs_mats[i].copy_(utils.toeplitz.toeplitz(col, row))
    actual = torch.matmul(lhs_mats, rhs_mats)

    # Fast toeplitz
    res = utils.toeplitz.toeplitz_matmul(cols, rows, rhs_mats)
    assert utils.approx_equal(res, actual)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号