def test_sparse_repeat_1d():
sparse_1d = sparse_getitem(sparse, 1)
actual = sparse_1d.to_dense().repeat(3, 1)
res = sparse_repeat(sparse_1d, 3, 1)
assert torch.equal(actual, res.to_dense())
actual = sparse_1d.to_dense().repeat(2, 3)
res = sparse_repeat(sparse_1d, 2, 3)
assert torch.equal(actual, res.to_dense())
评论列表
文章目录