def test_Index(self):
    """Check ``nn.Index(0)`` forward/backward on a 1-D and a 2-D source.

    The module is fed a ``[source, indices]`` pair. Forward must gather the
    entries (1-D) or rows (2-D) of ``source`` selected by ``indices``;
    backward must accumulate ``gradOutput`` back into a source-shaped
    gradient, so repeated indices sum and unselected rows stay zero.
    Finally, ``repr``/``str`` of the module must not raise.
    """
    module = nn.Index(0)

    # Each case: (source data, indices, expected forward output,
    #             gradOutput, expected gradient w.r.t. the source).
    cases = (
        # 1-D: index 1 is selected twice, so its gradient is 1 + 1 = 2.
        ((10, 20, 30), (0, 1, 1, 2), (10, 20, 20, 30),
         (1, 1, 1, 3), (1, 2, 3)),
        # 2-D: row 0 is selected twice ((1, 2) + (1, 2) = (2, 4));
        # row 1 is never selected, so its gradient is zero.
        (((10, 20), (30, 40)), (0, 0), ((10, 20), (10, 20)),
         ((1, 2), (1, 2)), ((2, 4), (0, 0))),
    )
    for source, indices, want_out, grad_out, want_grad in cases:
        pair = [torch.Tensor(source), torch.LongTensor(indices)]
        self.assertEqual(module.forward(pair), torch.Tensor(want_out))
        grads = module.backward(pair, torch.Tensor(grad_out))
        # Only the gradient w.r.t. the source tensor (element 0) is checked.
        self.assertEqual(grads[0], torch.Tensor(want_grad))

    # Smoke check: neither repr nor str should raise.
    module.__repr__()
    str(module)
评论列表
文章目录