def make_node(self, a, s=None):
    """Build the Apply node for this op from input ``a`` and optional ``s``.

    Parameters
    ----------
    a
        Symbolic input, converted with ``T.as_tensor_variable``. It must have
        at least 2 dimensions, the first of which is the batch dimension.
    s
        Optional shape argument. When omitted it defaults to ``a.shape[1:]``,
        i.e. one entry per non-batch dimension of ``a``. A user-supplied ``s``
        is converted with ``T.as_tensor_variable`` and must have an integer
        dtype (``int*`` or ``uint*``). NOTE(review): presumably these are the
        lengths of the transformed axes; confirm against ``perform``.

    Returns
    -------
    gof.Apply
        Node with inputs ``[a, s]`` and a single output created from
        ``self.output_type(a)``.

    Raises
    ------
    TypeError
        If ``a`` has fewer than 2 dimensions, or a user-supplied ``s`` does
        not have an integer dtype.
    """
    a = T.as_tensor_variable(a)
    if a.ndim < 2:
        # The guard accepts ndim == 2, so the message must say ">= 2"
        # (it previously said "> 2", contradicting the check).
        raise TypeError('%s: input must have dimension >= 2, with first '
                        'dimension batches' % self.__class__.__name__)

    if s is None:
        # Default: take the shape of every non-batch dimension of ``a``.
        s = T.as_tensor_variable(a.shape[1:])
    else:
        s = T.as_tensor_variable(s)
        # A tuple prefix covers both signed ('int*') and unsigned ('uint*')
        # dtypes; 'uint8' does not start with 'int', so both are needed.
        if not s.dtype.startswith(('int', 'uint')):
            raise TypeError('%s: length of the transformed axis must be'
                            ' of type integer' % self.__class__.__name__)
    return gof.Apply(self, [a, s], [self.output_type(a)()])
评论列表
文章目录