test_nn.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:pytorch 作者: pytorch 项目源码 文件源码
def test_orthogonal(self):
        for as_variable in [True, False]:
            for use_gain in [True, False]:
                for tensor_size in [[3, 4], [4, 3], [20, 2, 3, 4], [2, 3, 4, 5]]:
                    input_tensor = torch.zeros(tensor_size)
                    gain = 1.0

                    if as_variable:
                        input_tensor = Variable(input_tensor)

                    if use_gain:
                        gain = self._random_float(0.1, 2)
                        init.orthogonal(input_tensor, gain=gain)
                    else:
                        init.orthogonal(input_tensor)

                    if as_variable:
                        input_tensor = input_tensor.data

                    rows, cols = tensor_size[0], reduce(mul, tensor_size[1:])
                    flattened_tensor = input_tensor.view(rows, cols)
                    if rows > cols:
                        self.assertEqual(torch.mm(flattened_tensor.t(), flattened_tensor),
                                         torch.eye(cols) * gain ** 2, prec=1e-6)
                    else:
                        self.assertEqual(torch.mm(flattened_tensor, flattened_tensor.t()),
                                         torch.eye(rows) * gain ** 2, prec=1e-6)


# Generates rand tensor with non-equal values. This ensures that duplicate
# values won't be causing test failure for modules like MaxPooling.
# size should be small, otherwise randperm fails / long overflows.
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号