def make_node(self, x):
    """Build the Apply node that flattens ``x`` to ``self.outdim`` dimensions.

    The leading ``self.outdim - 1`` dimensions of the input are kept as-is;
    every remaining trailing dimension is collapsed into one last dimension.

    Parameters
    ----------
    x
        The value to flatten. It is first converted with
        ``as_tensor_variable``, so non-symbolic inputs accepted by that
        function may be passed as well as symbolic tensors.

    Returns
    -------
    gof.Apply
        A node whose single input is the converted ``t_x``. Its single output
        is a tensor with the input's dtype and ``self.outdim`` dimensions.
        For a 0-d input, the output is 1-d whatever ``self.outdim`` is; see
        the note on the rank check.

    Raises
    ------
    ValueError
        If ``self.outdim < 1``, or if the input has at least one dimension
        and ``self.outdim`` is larger than its rank.
    """
    t_x = as_tensor_variable(x)
    # Read every attribute from the converted variable ``t_x``, never from
    # the raw argument: a list or numpy array passed as ``x`` has no
    # ``broadcastable`` / ``type`` attribute (and a list has no ``ndim``).
    #
    # A 0-d input makes ``t_x.ndim`` falsy, which short-circuits the upper
    # bound check, so any ``outdim >= 1`` is accepted for it.
    # NOTE(review): the output is then 1-d even if outdim > 1 -- confirm
    # that this is intended.
    if self.outdim < 1 or (t_x.ndim and self.outdim > t_x.ndim):
        raise ValueError('invalid output ndimensions (%i) for tensor of '
                         'rank %i' % (self.outdim, t_x.ndim))

    # Infer the broadcastable pattern of the output. Every dimension not
    # touched by the flatten keeps its broadcast flag. The dimension formed
    # by collapsing the trailing dimensions is broadcastable iff all of the
    # collapsed dimensions were broadcastable.
    # (``python_all`` is presumably the builtin ``all``, so an empty slice,
    # as for a 0-d input, yields True.)
    bcast_kept_dims = t_x.broadcastable[:self.outdim - 1]
    bcast_new_dim = python_all(t_x.broadcastable[self.outdim - 1:])
    broadcastable = bcast_kept_dims + (bcast_new_dim,)

    return gof.Apply(self, [t_x], [tensor(t_x.type.dtype, broadcastable)])
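# 评论列表 ("comment list") -- scraped web-page residue, not part of the code
# 文章目录 ("table of contents") -- scraped web-page residue, not part of the code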