def infer_shape(self, node, i_shapes):
    """Infer the output shapes of ``node`` from its explicit shape input.

    The second input of ``node`` is taken as the requested shape.  When
    that shape is usable as-is, the shape of the second output is read
    from it one dimension at a time.

    Parameters
    ----------
    node
        Node being analysed.  ``node.inputs[1]`` is the shape input and
        ``node.outputs[1].ndim`` is the number of dimensions reported.
    i_shapes
        Shapes of the node's inputs.  Not used by this implementation.

    Returns
    -------
    list
        ``[None, dims]`` where ``dims`` is the list
        ``[shp[0], ..., shp[ndim - 1]]``.  The first entry is always
        ``None``.
        NOTE(review): presumably the first output has no meaningful
        shape -- confirm against the Op definition.

    Raises
    ------
    tensor.ShapeError
        If ``self.ndim_added`` is non-zero, or if the shape input is an
        empty constant ('automatic shape').  In both cases the result
        depends on the other arguments.  Raising is intended to make the
        caller fall back on the default ``infer_shape`` implementation.
    """
    shp = node.inputs[1]

    # A constant shape input exposes its value through ``data``; an empty
    # value means 'automatic shape'.  The three-element default is only a
    # non-empty placeholder so that inputs without ``data`` count as known.
    unknown_shape = len(getattr(shp, 'data', [0, 1, 2])) == 0

    # If shape == (), or ndim_added != 0, the output shape depends on the
    # remaining arguments and cannot be read off ``shp`` alone.
    if self.ndim_added != 0 or unknown_shape:
        raise tensor.ShapeError()

    # Index element by element rather than iterating or slicing:
    # ``shp`` is only required to support integer indexing.
    return [None, [shp[i] for i in range(node.outputs[1].ndim)]]
评论列表
文章目录