def _getAllowedShapes(self, shape):
''' Return set of allowed shapes that can be squeezed into given shape.
Examples
--------
>>> PB = ParamBag() # fixing K,D doesn't matter
>>> PB._getAllowedShapes(())
set([()])
>>> PB._getAllowedShapes((1,))
set([(), (1,)])
>>> aSet = PB._getAllowedShapes((23,))
>>> sorted(aSet)
[(23,)]
>>> sorted(PB._getAllowedShapes((3,1)))
[(3,), (3, 1)]
>>> sorted(PB._getAllowedShapes((1,1)))
[(), (1,), (1, 1)]
'''
assert isinstance(shape, tuple)
allowedShapes = set()
if len(shape) == 0:
allowedShapes.add(tuple())
return allowedShapes
shapeVec = np.asarray(shape, dtype=np.int32)
onesMask = shapeVec == 1
keepMask = np.logical_not(onesMask)
nOnes = sum(onesMask)
for b in range(2**nOnes):
bStr = np.binary_repr(b)
bStr = '0' * (nOnes - len(bStr)) + bStr
keepMask[onesMask] = np.asarray([int(x) > 0 for x in bStr])
curShape = shapeVec[keepMask]
allowedShapes.add(tuple(curShape))
return allowedShapes
评论列表
文章目录