def test_dirac_properties(self):
    """Check structural properties of tensors initialised by ``init.dirac``.

    For 3-, 4- and 5-dimensional tensors, both wrapped (``as_variable=True``)
    and raw, the test verifies that after initialisation:

    * the number of nonzero entries equals ``min(size(0), size(1))``, and
    * the sum of all entries equals that same value.

    Tensors come from ``self._create_random_nd_tensor``; the meaning of
    ``size_min`` / ``size_max`` is presumably the per-dimension size range
    -- TODO confirm against the helper.
    """
    for as_variable in [True, False]:
        for dims in [3, 4, 5]:
            input_tensor = self._create_random_nd_tensor(
                dims, size_min=1, size_max=5, as_variable=as_variable
            )
            # The return value is ignored: the checks below rely on
            # ``init.dirac`` having modified ``input_tensor`` itself.
            init.dirac(input_tensor)
            if as_variable:
                # Unwrap so the checks run on the underlying tensor data.
                input_tensor = input_tensor.data

            c_out, c_in = input_tensor.size(0), input_tensor.size(1)
            min_d = min(c_out, c_in)

            # Number of nonzeros must match the smaller of the first two dims.
            # Uses assertEqual rather than a bare ``assert`` so the check is
            # not stripped under ``python -O`` and reports both values on
            # failure.
            self.assertEqual(torch.nonzero(input_tensor).size(0), min_d)

            # The sum of values must match as well; it can have precision
            # issues, hence assertEqual instead of an exact ``==`` comparison.
            self.assertEqual(input_tensor.sum(), min_d)
评论列表
文章目录