def make_node(self, x, index):
assert isinstance(x.type, TypedListType)
if not isinstance(index, Variable):
if isinstance(index, slice):
index = Constant(SliceType(), index)
return Apply(self, [x, index], [x.type()])
else:
index = T.constant(index, ndim=0, dtype='int64')
return Apply(self, [x, index], [x.ttype()])
if isinstance(index.type, SliceType):
return Apply(self, [x, index], [x.type()])
elif isinstance(index, T.TensorVariable) and index.ndim == 0:
assert index.dtype == 'int64'
return Apply(self, [x, index], [x.ttype()])
else:
raise TypeError('Expected scalar or slice as index.')
评论列表
文章目录