test_nn.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:pytorch-coriander 作者: hughperkins 项目源码 文件源码
def _test_gather(self, output_device):
        inputs = (
            Variable(torch.randn(2, 4).cuda(0), requires_grad=True),
            Variable(torch.randn(2, 4).cuda(1), requires_grad=True)
        )
        result = dp.gather(inputs, output_device)
        self.assertEqual(result.size(), torch.Size([4, 4]))
        self.assertEqual(result[:2], inputs[0])
        self.assertEqual(result[2:], inputs[1])
        if output_device != -1:
            self.assertEqual(result.get_device(), output_device)
        else:
            self.assertFalse(result.is_cuda)
        grad = torch.randn(4, 4)
        if output_device != -1:
            grad = grad.cuda(output_device)
        result.backward(grad)
        self.assertEqual(inputs[0].grad.data, grad[:2])
        self.assertEqual(inputs[1].grad.data, grad[2:])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号