def test_Copy(self):
    """Exercise ``nn.Copy`` built from (DoubleTensor, FloatTensor).

    Checks that ``forward`` turns a double input into an equal-valued
    FloatTensor, that ``backward`` returns a DoubleTensor equal to the
    gradient it was given, and that ``forward`` still yields a FloatTensor
    after ``double()`` is called with ``dontCast`` set.
    """
    # Named ``source`` rather than ``input`` so the builtin is not shadowed.
    source = torch.randn(3, 4).double()
    c = nn.Copy(torch.DoubleTensor, torch.FloatTensor)

    output = c.forward(source)
    self.assertEqual(torch.typename(output), 'torch.FloatTensor')
    self.assertEqual(output, source.float(), 1e-6)

    # fill_ works in place, so ``output`` itself holds ones when it is
    # compared against gradInput below.
    gradInput = c.backward(source, output.fill_(1))
    self.assertEqual(torch.typename(gradInput), 'torch.DoubleTensor')
    self.assertEqual(gradInput, output.double(), 1e-6)

    # NOTE(review): dontCast presumably makes the module ignore type
    # conversions such as double() -- confirm against nn.Copy.
    c.dontCast = True
    c.double()
    # Run forward again: the earlier ``output`` local was already verified
    # above and is never rebound, so re-checking it would not exercise
    # double() at all.
    output = c.forward(source)
    self.assertEqual(torch.typename(output), 'torch.FloatTensor')

    # Check that these don't raise errors
    repr(c)
    str(c)
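# Comment list  (web-page navigation text, not part of the test code)
# Table of contents  (web-page navigation text, not part of the test code)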