orthogonal.py source file

python

Project: chainer-deconv  Author: germanRos
import numpy

from chainer import cuda


# Method of the initializer class defined in orthogonal.py; self.scale is
# the scaling factor applied to the generated orthogonal matrix.
def __call__(self, array):
    xp = cuda.get_array_module(array)
    if not array.shape:  # 0-dim case
        array[...] = self.scale
    elif not array.size:
        raise ValueError('Array to be initialized must be non-empty.')
    else:
        # Flatten to 2-D: one row per vector, all trailing axes merged.
        # numpy.prod returns float value when the argument is empty.
        flat_shape = (len(array), int(numpy.prod(array.shape[1:])))
        if flat_shape[0] > flat_shape[1]:
            raise ValueError('Cannot make orthogonal system because'
                             ' # of vectors ({}) is larger than'
                             ' that of dimensions ({})'.format(
                                 flat_shape[0], flat_shape[1]))
        a = numpy.random.normal(size=flat_shape)
        # we do not have cupy.linalg.svd for now
        u, _, v = numpy.linalg.svd(a, full_matrices=False)
        # pick the one with the correct shape
        q = u if u.shape == flat_shape else v
        # Copy to the array's own device (CPU or GPU), then rescale in place.
        array[...] = xp.asarray(q.reshape(array.shape))
        array *= self.scale
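
The same idea can be tried without Chainer. The following is a minimal NumPy-only sketch, not the project's code: the function name `orthogonal_init` and its `shape`, `scale`, and `rng` parameters are hypothetical and introduced only for illustration. It draws a Gaussian matrix, takes its reduced SVD, keeps the factor whose shape matches the flattened weight, and then checks that the rows are orthonormal up to the scale factor.

```python
import numpy


def orthogonal_init(shape, scale=1.0, rng=None):
    """Return an array of `shape` whose flattened rows are orthogonal.

    Hypothetical helper for illustration; assumes len(shape) >= 1.
    """
    rng = numpy.random.default_rng() if rng is None else rng
    rows = shape[0]
    # numpy.prod of an empty tuple is the float 1.0, hence the int() cast.
    cols = int(numpy.prod(shape[1:]))
    if rows > cols:
        raise ValueError(
            'Cannot make orthogonal system because # of vectors ({}) is '
            'larger than that of dimensions ({})'.format(rows, cols))
    a = rng.normal(size=(rows, cols))
    # Reduced SVD: u is (rows, rows), v is (rows, cols) when rows <= cols.
    u, _, v = numpy.linalg.svd(a, full_matrices=False)
    # Keep the factor whose shape matches the flattened weight.
    q = u if u.shape == (rows, cols) else v
    return scale * q.reshape(shape)


# Usage: a 3 x 2 x 4 weight is treated as 3 vectors of dimension 8.
w = orthogonal_init((3, 2, 4), scale=1.1, rng=numpy.random.default_rng(0))
flat = w.reshape(3, -1)
gram = flat @ flat.T  # expected to be close to scale**2 times the identity
is_orthogonal = numpy.allclose(gram, 1.1 ** 2 * numpy.eye(3))
```

Because the rows of the chosen factor are orthonormal, the Gram matrix of the flattened result should equal `scale ** 2` times the identity up to floating-point error, which is what `is_orthogonal` checks.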