test_basic.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:cupy 作者: cupy 项目源码 文件源码
def _check_copyto_where_multigpu_raises(self, dtype, ngpus):
        """Check that masked ``cupy.copyto`` rejects arrays on other devices.

        For every non-decreasing assignment of devices
        ``(dev1, dev2, dev3, dev4)`` drawn from ``range(ngpus)``, excluding
        the case where all four are the same device:

        * the destination ``a`` is allocated on ``dev1``;
        * the source ``b`` is allocated on ``dev2``;
        * the boolean mask ``c`` is allocated on ``dev3``.

        ``cupy.copyto(a, b, where=c)`` is then called with ``dev4`` as the
        current device. The call must raise ``ValueError`` whose message
        starts with ``'Array device must be same as the current device'``.

        Args:
            self: the test case, forwarded to ``six.assertRaisesRegex``.
            dtype: dtype used for ``a`` and ``b``; the mask ``c`` always
                uses dtype ``'?'``.
            ngpus: number of devices to draw from. With ``ngpus <= 1``
                every assignment puts all arrays on one device, so nothing
                is checked.
        """
        # combinations_with_replacement yields exactly the non-decreasing
        # 4-tuples, in the same lexicographic order as filtering
        # itertools.product, without generating the ngpus**4 rejected ones.
        device_tuples = itertools.combinations_with_replacement(
            range(ngpus), 4)
        for dev1, dev2, dev3, dev4 in device_tuples:
            if dev1 == dev2 == dev3 == dev4:
                # Everything on one device is the valid case; skip it.
                continue

            with cuda.Device(dev1):
                a = testing.shaped_arange((2, 3, 4), cupy, dtype)
            with cuda.Device(dev2):
                b = testing.shaped_reverse_arange((2, 3, 4), cupy, dtype)
            with cuda.Device(dev3):
                c = testing.shaped_arange((2, 3, 4), cupy, '?')
            with cuda.Device(dev4):
                with six.assertRaisesRegex(
                        self, ValueError,
                        '^Array device must be same as the current device'):
                    cupy.copyto(a, b, where=c)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号