def test_ParallelTable(self):
    """Round-trip a (3, 4, 5) tensor through SplitTable -> ParallelTable -> JoinTable.

    ``SplitTable(0)`` slices the input along its first axis, each slice is
    reshaped to (4, 5, 1) by its own ``View`` branch, and ``JoinTable(2)``
    concatenates the branches along the last axis. The forward result is
    asserted to equal the input permuted to (4, 5, 3). Feeding that permuted
    tensor back as gradOutput is asserted to reproduce the input as gradInput.
    """
    x = torch.randn(3, 4, 5)

    branches = nn.ParallelTable()
    for _ in range(3):  # one View branch per slice along dim 0
        branches.add(nn.View(4, 5, 1))

    pipeline = nn.Sequential()
    pipeline.add(nn.SplitTable(0))
    pipeline.add(branches)
    pipeline.add(nn.JoinTable(2))

    # Smoke-check the string representations; only the absence of errors matters.
    repr(branches)
    str(branches)

    result = pipeline.forward(x)
    # (3, 4, 5) -> (4, 5, 3): expected[i, j, k] == x[k, i, j]
    expected = x.transpose(0, 2).transpose(0, 1)
    self.assertEqual(expected, result)

    grad = pipeline.backward(x, expected)
    self.assertEqual(grad, x)
# Page-navigation residue from the source web page: "Comment list" (评论列表)
# and "Table of contents" (文章目录).