test_torch.py source code (Python)

Project: pytorch-dist    Author: apaszke
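import random
import torch

def test_scatterFill(self):
    # Random sizes for a 3-D target tensor, a random scatter dimension,
    # and how many indices to write along that dimension per "row".
    m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
    elems_per_row = random.randint(1, 10)
    dim = random.randrange(3)

    # The index tensor has the target's shape, except along dim,
    # where it holds only elems_per_row entries.
    val = random.random()
    idx_size = [m, n, o]
    idx_size[dim] = elems_per_row
    idx = torch.LongTensor().resize_(*idx_size)
    # _fill_indices is a helper method of this test class (body not shown here);
    # it presumably fills idx with indices in range([m, n, o][dim]).
    self._fill_indices(idx, dim, [m, n, o][dim], elems_per_row, m, n, o)

    # scatter_ with a scalar writes val at every position named by idx.
    actual = torch.zeros(m, n, o).scatter_(dim, idx, val)
    # Reference result: take each (i, j, k), replace its dim-th coordinate
    # with idx[i, j, k], and write val there.
    expected = torch.zeros(m, n, o)
    for i in range(idx_size[0]):
        for j in range(idx_size[1]):
            for k in range(idx_size[2]):
                ii = [i, j, k]
                ii[dim] = idx[i, j, k]
                expected[tuple(ii)] = val
    # Third argument is the allowed tolerance; 0 demands an exact match.
    self.assertEqual(actual, expected, 0)

    # 28 exceeds every dimension size (at most 20), so scatter_ must raise.
    idx[0][0][0] = 28
    self.assertRaises(RuntimeError, lambda: torch.zeros(m, n, o).scatter_(dim, idx, val))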
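The test boils down to one rule: for every position (i, j, k) in idx, scatter_ takes that position, replaces its dim-th coordinate with idx[i][j][k], and writes val there. The following is a minimal pure-Python sketch of that rule using nested lists instead of tensors, so it runs without PyTorch. The names zeros3, fill_unique_indices and scatter_fill are hypothetical. In particular, fill_unique_indices is an assumed stand-in for _fill_indices (whose real body is not shown): it draws distinct indices per row, which lets the sketch count written cells exactly. With a scalar fill value, duplicate indices would only write the same val twice, so the test itself does not depend on distinctness. The sketch raises IndexError on an out-of-range index, loosely mirroring the RuntimeError the test expects from PyTorch.

```python
import random


def zeros3(shape, fill=0.0):
    """Nested-list stand-in for torch.zeros on a 3-D shape."""
    a, b, c = shape
    return [[[fill] * c for _ in range(b)] for _ in range(a)]


def fill_unique_indices(sizes, dim, elems_per_row, rng):
    """Hypothetical stand-in for _fill_indices (assumption, not the original):
    for each position on the two non-`dim` axes, draw `elems_per_row`
    distinct indices from range(sizes[dim])."""
    idx_size = list(sizes)
    idx_size[dim] = elems_per_row
    idx = zeros3(idx_size, fill=0)
    d1, d2 = [d for d in range(3) if d != dim]
    for a in range(idx_size[d1]):
        for b in range(idx_size[d2]):
            picks = rng.sample(range(sizes[dim]), elems_per_row)
            for t, p in enumerate(picks):
                pos = [0, 0, 0]
                pos[d1], pos[d2], pos[dim] = a, b, t
                idx[pos[0]][pos[1]][pos[2]] = p
    return idx, idx_size


def scatter_fill(sizes, dim, idx, idx_size, val):
    """Reference rule for zeros(*sizes).scatter_(dim, idx, val): for every
    (i, j, k) in idx, replace coordinate `dim` with idx[i][j][k] and write
    val there. Out-of-range indices raise IndexError."""
    out = zeros3(sizes)
    for i in range(idx_size[0]):
        for j in range(idx_size[1]):
            for k in range(idx_size[2]):
                target = [i, j, k]
                target[dim] = idx[i][j][k]
                if not 0 <= target[dim] < sizes[dim]:
                    raise IndexError(
                        f"index {target[dim]} out of range for dim {dim} "
                        f"(size {sizes[dim]})")
                out[target[0]][target[1]][target[2]] = val
    return out


rng = random.Random(0)
sizes = [rng.randint(10, 20) for _ in range(3)]
dim = rng.randrange(3)
elems_per_row = rng.randint(1, 10)
val = rng.random()
idx, idx_size = fill_unique_indices(sizes, dim, elems_per_row, rng)
out = scatter_fill(sizes, dim, idx, idx_size, val)

# Distinct indices per row mean no two writes land on the same cell,
# so exactly one cell per idx entry holds val.
written = sum(x == val for plane in out for row in plane for x in row)
print(written == idx_size[0] * idx_size[1] * idx_size[2])  # True

# Same negative check as the test: 28 is past every dimension (max 20).
idx[0][0][0] = 28
try:
    scatter_fill(sizes, dim, idx, idx_size, val)
except IndexError as e:
    print("rejected:", e)
```

Nested lists keep the rule explicit but are slow; the sketch only mirrors the expected-value loop in the test and says nothing about how PyTorch's own kernel is written.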