def test_type(self):
l = nn.Linear(10, 20)
net = nn.Container(
l=l,
l2=l,
empty=None,
)
net.float()
self.assertIsInstance(l.weight.data, torch.FloatTensor)
self.assertIsInstance(l.bias.data, torch.FloatTensor)
net.double()
self.assertIsInstance(l.weight.data, torch.DoubleTensor)
self.assertIsInstance(l.bias.data, torch.DoubleTensor)
# Comment list
# Table of contents