test_nn.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:pytorch 作者: pytorch 项目源码 文件源码
def test_data_parallel_sparse(self):
    """Check dp.data_parallel on a sparse nn.Embedding against a one-GPU run.

    A reference forward/backward pass runs on GPU 1 to record the expected
    output and parameter gradients. data_parallel is then run with the
    device orders (0, 1) and (1, 0). For each order, the output must land on
    the first listed device, and both the output and the gradients must
    match the reference. Finally, data_parallel is called with device_ids
    omitted, and its output must also match the reference.

    Requires at least two CUDA devices (devices 0 and 1 are hard-coded).
    """
    l = nn.Embedding(10, 5, sparse=True).cuda(1)
    i = Variable(torch.LongTensor(20, 5).random_(0, 10).cuda(1))

    # Single-device reference pass on GPU 1.
    expected_out = l(i)
    loss = expected_out.sum()
    loss.backward()
    # Clone so the recorded reference survives the zero_grad() calls below.
    expected_grads = [param.grad.clone() for param in l.parameters()]

    dev_ids_list = [(0, 1), (1, 0)]
    for dev_id in dev_ids_list:
        with torch.cuda.device(dev_id[0]):
            # data_parallel expects the module on device_ids[0]; cuda() with
            # no argument moves it to the current device set by the context.
            l.cuda()
            l.zero_grad()
            out = dp.data_parallel(l, i, dev_id)
            loss = out.sum()
            loss.backward()
            self.assertEqual(out.get_device(), dev_id[0])
            self.assertEqual(out.data, expected_out.data)
            for expected, param in zip(expected_grads, l.parameters()):
                self.assertEqual(param.grad.data, expected.data)

    # Check for None device_ids. All visible GPUs are used, and the output is
    # gathered on device_ids[0], i.e. device 0. The result must be verified
    # here, not just computed; otherwise this path only tests "does not raise".
    l = l.cuda()
    out = dp.data_parallel(l, i)
    self.assertEqual(out.get_device(), 0)
    self.assertEqual(out.data, expected_out.data)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号