test_multiprocessing.py source code

python
Project: pytorch-coriander  Author: hughperkins
def test_cuda_small_tensors(self):
        # Check multiple small tensors which will likely use the same
        # underlying cached allocation
        ctx = mp.get_context('spawn')
        tensors = []
        for i in range(5):
            tensors += [torch.arange(i * 5, (i + 1) * 5).cuda()]

        inq = ctx.Queue()
        outq = ctx.Queue()
        inq.put(tensors)
        p = ctx.Process(target=sum_tensors, args=(inq, outq))
        p.start()

        results = []
        for i in range(5):
            results.append(outq.get())
        p.join()

        for i, tensor in enumerate(tensors):
            v, device, tensor_size, storage_size = results[i]
            self.assertEqual(v, torch.arange(i * 5, (i + 1) * 5).sum())
            self.assertEqual(device, 0)
            self.assertEqual(tensor_size, 5)
            self.assertEqual(storage_size, 5)
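
The test above hands a list of tensors to a child process through `inq` and reads one result per tensor back from `outq`, but the `sum_tensors` worker is not part of this excerpt. Below is a minimal CPU-only sketch of that queue round trip, assuming a worker that receives the whole list as one message and replies once per item. The names `sum_chunks` and `run_round_trip` are hypothetical, plain Python lists stand in for CUDA tensors, and only the sum and element count are reported rather than the four values the test unpacks.

```python
import multiprocessing as mp


def sum_chunks(inq, outq):
    """Hypothetical stand-in for sum_tensors: put one result per received item."""
    chunks = inq.get()  # the whole list arrives as a single queue message
    for chunk in chunks:
        # The original test unpacks (sum, device, tensor size, storage size);
        # this sketch reports only the sum and the element count.
        outq.put((sum(chunk), len(chunk)))


def run_round_trip():
    """Mirror the test's flow with a spawned child process (no GPU involved)."""
    ctx = mp.get_context("spawn")
    # Five chunks of five consecutive integers, like the arange calls in the test.
    chunks = [list(range(i * 5, (i + 1) * 5)) for i in range(5)]
    inq, outq = ctx.Queue(), ctx.Queue()
    inq.put(chunks)
    p = ctx.Process(target=sum_chunks, args=(inq, outq))
    p.start()
    # Drain the results before joining so the child is never blocked on a full pipe.
    results = [outq.get() for _ in chunks]
    p.join()
    return results
```

With the `spawn` start method the worker must be importable from the main module, so `run_round_trip()` should be called from a script under an `if __name__ == "__main__":` guard.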