def test_dirac_identity(self):
    """Check that a dirac-initialised filter turns convolution into an identity.

    For 1D, 2D and 3D convolutions, a zero filter of shape
    ``(out_c, in_c, kernel_size, ...)`` is passed through ``init.dirac`` and
    applied to random input. Two properties are then asserted:

    * the first ``in_c`` output channels equal the input with one element
      cropped from each spatial border;
    * the remaining ``out_c - in_c`` output channels contain no non-zero
      element.
    """
    batch, in_c, out_c, size, kernel_size = 8, 3, 4, 5, 3

    # One entry per dimensionality: number of spatial dims and the matching conv.
    conv_by_ndim = ((1, F.conv1d), (2, F.conv2d), (3, F.conv3d))

    for ndim, conv in conv_by_ndim:
        input_shape = (batch, in_c) + (size,) * ndim
        filter_shape = (out_c, in_c) + (kernel_size,) * ndim

        input_var = Variable(torch.randn(*input_shape))
        filter_var = Variable(torch.zeros(*filter_shape))
        init.dirac(filter_var)
        output_var = conv(input_var, filter_var)

        # Variables do not support nonzero, so compare the underlying tensors.
        input_tensor, output_tensor = input_var.data, output_var.data

        # Equivalent to input_tensor[:, :, 1:-1, ...] with one `1:-1` per
        # spatial dimension: the output is compared against the input interior.
        interior = (slice(None), slice(None)) + (slice(1, -1),) * ndim

        # Assert in_c outputs are preserved.
        self.assertEqual(input_tensor[interior], output_tensor[:, :in_c])

        # Assert extra outputs are 0. A real test assertion is used rather
        # than a bare `assert`, which is stripped when running with -O.
        self.assertEqual(torch.nonzero(output_tensor[:, in_c:]).numel(), 0)
# Comment list
# Table of contents