def make_node(self, x, shp):
    """Build the graph node that reshapes ``x`` to the target shape ``shp``.

    NOTE(review): this reads like the ``make_node`` of a reshape Op whose
    class header lies outside the visible source -- confirm against the
    enclosing class.

    Parameters
    ----------
    x
        Anything ``as_tensor_variable`` accepts; it is converted first.
    shp
        The target shape. It is converted to a 1-d tensor variable, which
        must have an integer dtype unless it is an empty constant. When the
        converted shape is not a constant, the *original* argument is
        indexed entry by entry for positions ``0 .. self.ndim - 1``, so it
        must support ``shp[i]``; a 0-d original argument is treated as a
        single entry.

    Returns
    -------
    gof.Apply
        A node with inputs ``[x, shp]`` (``shp`` converted) and a single
        output tensor of ``x``'s dtype. Its broadcastable flags are set
        for the dimensions whose length is known to be 1.

    Raises
    ------
    TypeError
        If the converted shape has a non-integer dtype and is not an empty
        constant.
    """
    x = as_tensor_variable(x)
    raw_shape = shp
    shp = as_tensor_variable(shp, ndim=1)

    if not shp.dtype.startswith('int'):
        # A non-integer dtype is only acceptable for an empty constant
        # shape: with no entries, the dtype no longer matters.
        empty_constant = (isinstance(shp, TensorConstant) and
                          shp.data.size == 0)
        if not empty_constant:
            raise TypeError("Shape must be integers", shp, shp.dtype)
    assert shp.ndim == 1

    if isinstance(shp, TensorConstant):
        # Every entry is known, so each length-1 dimension is broadcastable.
        # NOTE(review): the number of entries is not compared with
        # self.ndim on this path.
        broadcastable = [extent == 1 for extent in shp.data]
    else:
        # Symbolic shape: look at the caller's original entries one by one
        # and mark a dimension broadcastable only when its entry can be
        # shown to be the constant 1.
        if hasattr(raw_shape, "ndim") and raw_shape.ndim == 0:
            entries = [raw_shape]
        else:
            entries = raw_shape

        def _constant_one(entry):
            # Truthy only when the entry is provably equal to 1; an entry
            # whose constant value cannot be determined yields False.
            entry = as_tensor_variable(entry)
            try:
                return (hasattr(entry, 'get_scalar_constant_value') and
                        entry.get_scalar_constant_value() == 1)
            except NotScalarConstantError:
                return False

        broadcastable = [_constant_one(entries[pos])
                         for pos in xrange(self.ndim)]

    return gof.Apply(self, [x, shp], [tensor(x.type.dtype, broadcastable)])
评论列表
文章目录