def test_kaiming_normal(self):
    """Check that ``init.kaiming_normal`` yields the expected zero-mean spread.

    Every combination of ``as_variable``, explicit versus omitted ``a``,
    tensor rank (2 or 4) and ``mode`` ('fan_in' or 'fan_out') is exercised.
    After each call the tensor must satisfy ``self._is_normal`` for mean 0
    and standard deviation ``sqrt(2 / ((1 + a**2) * n))``, where ``n`` is
    the fan selected by ``mode``.
    """
    # Same visiting order as four nested loops: as_variable varies slowest,
    # mode varies fastest.
    cases = (
        (wrap, explicit_a, rank, fan_mode)
        for wrap in (True, False)
        for explicit_a in (True, False)
        for rank in (2, 4)
        for fan_mode in ('fan_in', 'fan_out')
    )
    for as_variable, use_a, dims, mode in cases:
        # NOTE(review): presumably builds a `dims`-dimensional tensor whose
        # sizes lie in [size_min, size_max] -- confirm against the helper.
        tensor = self._create_random_nd_tensor(
            dims, size_min=20, size_max=25, as_variable=as_variable)

        # Draw the slope only when it is passed explicitly; otherwise the
        # initializer is called without ``a`` and 0 is used for the
        # expected standard deviation below.
        init_kwargs = {'mode': mode}
        if use_a:
            a = self._random_float(0.1, 2)
            init_kwargs['a'] = a
        else:
            a = 0
        init.kaiming_normal(tensor, **init_kwargs)

        # Unwrap so the checks below operate on the underlying data.
        values = tensor.data if as_variable else tensor

        # For more than two dimensions, both fans are scaled by the element
        # count of a single [0, 0] slice.
        slice_numel = values[0, 0].numel() if values.dim() > 2 else 1
        leading = values.size(1) if mode == 'fan_in' else values.size(0)
        n = leading * slice_numel

        expected_std = math.sqrt(2.0 / ((1 + a**2) * n))
        assert self._is_normal(values, 0, expected_std)
# Page-navigation residue from the scraped source: "Comment list", "Table of contents".