def make_node(self, a, s=None):
    """Build the Apply node for this op from input `a` and output shape `s`.

    Parameters
    ----------
    a
        Input; converted with ``T.as_tensor_variable``. Must have at least
        3 dimensions (per the error message: first dimension is the batch,
        last dimension holds the real/imag parts).
    s
        Optional lengths of the transformed axes. When omitted, it is
        derived from ``a.shape[1:-1]`` with the last entry replaced by
        ``(last - 1) * 2``. When given, it is converted to a tensor
        variable and must have an integer dtype.

    Returns
    -------
    ``gof.Apply`` with inputs ``[a, s]`` and a single output created from
    ``self.output_type(a)``.

    Raises
    ------
    TypeError
        If `a` has fewer than 3 dimensions, or if an explicitly supplied
        `s` does not have an ``int*`` / ``uint*`` dtype.
    """
    a = T.as_tensor_variable(a)
    if a.ndim < 3:
        op_name = self.__class__.__name__
        raise TypeError('%s: input must have dimension >= 3, with ' %
                        op_name +
                        'first dimension batches and last real/imag parts')

    if s is not None:
        s = T.as_tensor_variable(s)
        # Accept both signed ('int*') and unsigned ('uint*') dtypes.
        if not s.dtype.startswith(('int', 'uint')):
            raise TypeError('%s: length of the transformed axis must be'
                            ' of type integer' % self.__class__.__name__)
    else:
        # Drop the batch axis and the trailing real/imag axis.
        inner_shape = a.shape[1:-1]
        # NOTE(review): presumably undoes the n // 2 + 1 length of a real
        # FFT's last axis (assumes an even original length) -- confirm.
        last_len = (inner_shape[-1] - 1) * 2
        inner_shape = T.set_subtensor(inner_shape[-1], last_len)
        s = T.as_tensor_variable(inner_shape)

    out_var = self.output_type(a)()
    return gof.Apply(self, [a, s], [out_var])
# Comment list
# Article table of contents