test_nn.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:pytorch 作者: tylergenter 项目源码 文件源码
def test_xavier_uniform(self):
        # Exercise init.xavier_uniform over every combination of:
        #   * as_variable True / False,
        #   * an explicit random gain / the default gain,
        #   * dims of 2 / 4,
        # and check the initialised values against the interval [-a, a] with
        #   a = gain * sqrt(2 / (fan_in + fan_out)) * sqrt(3).
        for wrap_in_variable in [True, False]:
            for with_gain in [True, False]:
                for ndim in [2, 4]:
                    tensor = self._create_random_nd_tensor(
                        ndim, size_min=20, size_max=25, as_variable=wrap_in_variable)

                    # Forward ``gain`` only when one was drawn, so the call
                    # without the keyword (default gain) is covered as well.
                    gain = 1
                    init_kwargs = {}
                    if with_gain:
                        gain = self._random_float(0.1, 2)
                        init_kwargs["gain"] = gain
                    init.xavier_uniform(tensor, **init_kwargs)

                    # When as_variable was requested, continue with ``.data``.
                    values = tensor.data if wrap_in_variable else tensor

                    # Beyond two dimensions both fans are scaled by the number
                    # of elements in a single [0, 0] slice.
                    slice_elems = values[0, 0].numel() if values.dim() > 2 else 1
                    fan_in = values.size(1) * slice_elems
                    fan_out = values.size(0) * slice_elems

                    # A uniform distribution on [-a, a] has std a / sqrt(3),
                    # hence the sqrt(3) factor to go from std to the bound.
                    target_std = gain * math.sqrt(2.0 / (fan_in + fan_out))
                    limit = target_std * math.sqrt(3)
                    assert self._is_uniform(values, -limit, limit)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号