def test_dilate(self):
    """Check the output sizes and element placement produced by ``dilate``.

    A 13-element ramp (values 0..12, viewed as a (1, 1, 13) tensor) is
    dilated with factors 1, 2 and 4. The factor-4 result is then passed
    through ``dilate`` again with factor 1. After each call the test asserts
    the output size and the index at which the element with value 4 lands.

    NOTE(review): the asserted sizes (2 * 7 = 14 and 4 * 4 = 16 elements)
    imply that ``dilate`` pads the 13-element input. The asserted positions of
    the value 4 are consistent with that padding being added at the start of
    the sequence. Confirm against ``dilate``'s implementation.
    """
    x = Variable(torch.arange(0, 13).view(1, 1, 13))

    # Elements are read as ``tensor.data[i, j, k]`` rather than
    # ``tensor[i, j, k].data[0]``. Full indexing gives a Python scalar on old
    # PyTorch and a 0-dim tensor on PyTorch >= 0.4, and both compare equal to
    # an int. ``.data[0]`` on a 0-dim tensor raises IndexError on current
    # PyTorch.

    # Factor 1: size is unchanged and the value 4 stays at index 4.
    dilated, _ = dilate(x, 1)
    self.assertEqual(dilated.size(), (1, 1, 13))
    self.assertEqual(dilated.data[0, 0, 4], 4)

    # Factor 2: the sequence is folded into 2 rows of length 7.
    dilated, _ = dilate(x, 2)
    self.assertEqual(dilated.size(), (2, 1, 7))
    self.assertEqual(dilated.data[1, 0, 2], 4)

    # Factor 4: the sequence is folded into 4 rows of length 4.
    dilated, _ = dilate(x, 4)
    self.assertEqual(dilated.size(), (4, 1, 4))
    self.assertEqual(dilated.data[3, 0, 1], 4)

    # Re-dilating the factor-4 result with factor 1 is expected to produce a
    # single row holding all 16 (padded) elements.
    dilated, _ = dilate(dilated, 1)
    self.assertEqual(dilated.size(), (1, 1, 16))
    self.assertEqual(dilated.data[0, 0, 7], 4)
评论列表
文章目录