def test_xavier_normal(self):
for as_variable in [True, False]:
for use_gain in [True, False]:
for dims in [2, 4]:
input_tensor = self._create_random_nd_tensor(dims, size_min=20, size_max=25,
as_variable=as_variable)
gain = 1
if use_gain:
gain = self._random_float(0.1, 2)
init.xavier_normal(input_tensor, gain=gain)
else:
init.xavier_normal(input_tensor)
if as_variable:
input_tensor = input_tensor.data
fan_in = input_tensor.size(1)
fan_out = input_tensor.size(0)
if input_tensor.dim() > 2:
fan_in *= input_tensor[0, 0].numel()
fan_out *= input_tensor[0, 0].numel()
expected_std = gain * math.sqrt(2.0 / (fan_in + fan_out))
assert self._is_normal(input_tensor, 0, expected_std)
评论列表
文章目录