def make_node(self, x, axis, splits):
    """Build the Apply node for this op from `x`, `axis` and `splits`.

    Each argument is first converted with `as_tensor_variable`; the
    converted `splits` and `axis` are then type-checked.

    Parameters
    ----------
    x
        Input value; every output variable is created from its type.
    axis
        Must convert to a variable whose type is in `int_types`
        (presumably the axis to split along -- confirm against `perform`).
    splits
        Must convert to a variable whose type is in `int_vector_types`
        (presumably the size of each piece -- confirm against `perform`).

    Returns
    -------
    Apply
        Node with inputs ``[x, axis, splits]`` and ``self.len_splits``
        outputs, each a fresh variable produced by ``x.type()``.

    Raises
    ------
    TypeError
        If the type of `splits` is not in `int_vector_types`, or the type
        of `axis` is not in `int_types`.
    """
    x = as_tensor_variable(x)
    axis = as_tensor_variable(axis)
    splits = as_tensor_variable(splits)

    # NOTE(review): the messages below name lvector / lscalar, but the checks
    # accept any member of int_vector_types / int_types. The text is left
    # unchanged in case callers match on it.
    if splits.type not in int_vector_types:
        raise TypeError('splits must have type tensor.lvector',
                        splits.type)
    if axis.type not in int_types:
        raise TypeError('axis must have type lscalar', axis.type)

    # The following lines are necessary if we allow splits of zero:
    #     if isinstance(axis, gof.Constant):
    #         x = unbroadcast(x, int(axis.data))
    #     else:
    #         x = unbroadcast(x, *range(x.type.ndim))

    inputs = [x, axis, splits]
    # Builtin range (not the Python-2-only xrange) keeps this working on both
    # Python 2 and Python 3 without relying on a compatibility import.
    outputs = [x.type() for _ in range(self.len_splits)]
    return Apply(self, inputs, outputs)
评论列表
文章目录