test_nn.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:pytorch 作者: tylergenter 项目源码 文件源码
def test_kaiming_normal(self):
        for as_variable in [True, False]:
            for use_a in [True, False]:
                for dims in [2, 4]:
                    for mode in ['fan_in', 'fan_out']:
                        input_tensor = self._create_random_nd_tensor(dims, size_min=20, size_max=25,
                                                                     as_variable=as_variable)
                        if use_a:
                            a = self._random_float(0.1, 2)
                            init.kaiming_normal(input_tensor, a=a, mode=mode)
                        else:
                            a = 0
                            init.kaiming_normal(input_tensor, mode=mode)

                        if as_variable:
                            input_tensor = input_tensor.data

                        fan_in = input_tensor.size(1)
                        fan_out = input_tensor.size(0)
                        if input_tensor.dim() > 2:
                            fan_in *= input_tensor[0, 0].numel()
                            fan_out *= input_tensor[0, 0].numel()

                        if mode == 'fan_in':
                            n = fan_in
                        else:
                            n = fan_out

                        expected_std = math.sqrt(2.0 / ((1 + a**2) * n))
                        assert self._is_normal(input_tensor, 0, expected_std)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号