ParamBag.py 文件源码

python
阅读 40 收藏 0 点赞 0 评论 0

项目:bnpy 作者: bnpy 项目源码 文件源码
def _getAllowedShapes(self, shape):
        ''' Return set of allowed shapes that can be squeezed into given shape.

        Examples
        --------
        >>> PB = ParamBag() # fixing K,D doesn't matter
        >>> PB._getAllowedShapes(())
        set([()])
        >>> PB._getAllowedShapes((1,))
        set([(), (1,)])
        >>> aSet = PB._getAllowedShapes((23,))
        >>> sorted(aSet)
        [(23,)]
        >>> sorted(PB._getAllowedShapes((3,1)))
        [(3,), (3, 1)]
        >>> sorted(PB._getAllowedShapes((1,1)))
        [(), (1,), (1, 1)]
        '''
        assert isinstance(shape, tuple)
        allowedShapes = set()
        if len(shape) == 0:
            allowedShapes.add(tuple())
            return allowedShapes
        shapeVec = np.asarray(shape, dtype=np.int32)
        onesMask = shapeVec == 1
        keepMask = np.logical_not(onesMask)
        nOnes = sum(onesMask)
        for b in range(2**nOnes):
            bStr = np.binary_repr(b)
            bStr = '0' * (nOnes - len(bStr)) + bStr
            keepMask[onesMask] = np.asarray([int(x) > 0 for x in bStr])
            curShape = shapeVec[keepMask]
            allowedShapes.add(tuple(curShape))
        return allowedShapes
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号