def test_type_conversions(self):
    """Check that cast and device-transfer methods on a Variable yield the expected tensor types.

    Covers:
      * CPU casts via ``float()`` / ``int()``;
      * CPU -> CUDA -> CPU round trips when CUDA is available;
      * explicit device placement (``cuda(1)``, GPU-to-GPU moves) when at
        least two GPUs are present;
      * ``type_as`` against both raw tensors and Variables (CPU only, so it
        runs regardless of CUDA availability).
    """
    x = Variable(torch.randn(5, 5))
    self.assertIs(type(x.float().data), torch.FloatTensor)
    self.assertIs(type(x.int().data), torch.IntTensor)

    if torch.cuda.is_available():
        self.assertIs(type(x.float().cuda().data), torch.cuda.FloatTensor)
        self.assertIs(type(x.int().cuda().data), torch.cuda.IntTensor)
        self.assertIs(type(x.int().cuda().cpu().data), torch.IntTensor)

        # Device index 1 exists as soon as there are two GPUs; a ``> 2``
        # guard would silently skip this section on 2-GPU machines.
        if torch.cuda.device_count() >= 2:
            # Explicit target device.
            x2 = x.float().cuda(1)
            self.assertIs(type(x2.data), torch.cuda.FloatTensor)
            # get_device() returns an int: compare by value, not identity.
            self.assertEqual(x2.get_device(), 1)

            # The test expects an argument-less cuda() call to land on device 0.
            x2 = x.float().cuda()
            self.assertIs(type(x2.data), torch.cuda.FloatTensor)
            self.assertEqual(x2.get_device(), 0)

            # GPU -> GPU move keeps the CUDA tensor type but changes the device.
            x2 = x2.cuda(1)
            self.assertIs(type(x2.data), torch.cuda.FloatTensor)
            self.assertEqual(x2.get_device(), 1)

    # type_as must adopt the target's tensor type whether the target is a raw
    # tensor or a Variable wrapping one.
    for t in [torch.DoubleTensor, torch.FloatTensor, torch.IntTensor, torch.ByteTensor]:
        for var in (True, False):
            y = torch.randn(5, 5).type(t)
            if var:
                y = Variable(y)
            self.assertIs(type(x.type_as(y).data), t)
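# NOTE: "评论列表" ("comment list") and "文章目录" ("table of contents") were
# blog-page navigation labels captured along with this snippet; they are not
# part of the test code.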