def test_sparse_repeat_2d():
actual = sparse.to_dense().repeat(3, 2)
res = sparse_repeat(sparse, 3, 2)
assert torch.equal(actual, res.to_dense())
actual = sparse.to_dense().repeat(1, 2)
res = sparse_repeat(sparse, 1, 2)
assert torch.equal(actual, res.to_dense())
actual = sparse.to_dense().repeat(3, 1)
res = sparse_repeat(sparse, 3, 1)
assert torch.equal(actual, res.to_dense())
评论列表
文章目录