import torch
import torch.nn as nn
import torch.nn.parallel as dp
from torch.autograd import Variable


def test_data_parallel_sparse(self):
    # Reference run on a single GPU: forward, backward, and record the
    # expected sparse gradients.
    l = nn.Embedding(10, 5, sparse=True).cuda(1)
    i = Variable(torch.LongTensor(20, 5).random_(0, 10).cuda(1))
    expected_out = l(i)
    loss = expected_out.sum()
    loss.backward()
    expected_grads = []
    for param in l.parameters():
        expected_grads.append(param.grad.clone())

    # Run data_parallel with both device orderings and compare the output
    # and gradients against the single-GPU reference.
    dev_ids_list = [(0, 1), (1, 0)]
    for dev_id in dev_ids_list:
        with torch.cuda.device(dev_id[0]):
            l.cuda()
            l.zero_grad()
            out = dp.data_parallel(l, i, dev_id)
            loss = out.sum()
            loss.backward()
            self.assertEqual(out.get_device(), dev_id[0])
            self.assertEqual(out.data, expected_out.data)
            for expected, param in zip(expected_grads, l.parameters()):
                self.assertEqual(param.grad.data, expected.data)

    # Check for None device_ids: data_parallel should fall back to all
    # visible GPUs.
    l = l.cuda()
    out = dp.data_parallel(l, i)
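Note: the test above uses the legacy Variable wrapper, which current PyTorch releases have folded into plain tensors. A minimal sketch of the same data_parallel call on a recent build might look like the following; it is illustrative only (not part of the test suite) and assumes at least two visible GPUs, with the shapes and device ids chosen to match the test.

import torch
import torch.nn as nn
from torch.nn.parallel import data_parallel

# Sketch only: same sparse-embedding setup as the test, modern tensor API.
l = nn.Embedding(10, 5, sparse=True).cuda(0)
i = torch.randint(0, 10, (20, 5), device="cuda:0")
out = data_parallel(l, i, device_ids=[0, 1])  # replicate across GPUs 0 and 1
out.sum().backward()  # the sparse gradient accumulates on l.weight.grad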