def make_node(self, x, index):
    """Build the Apply node that extracts one scalar from a sparse matrix.

    Parameters
    ----------
    x
        Anything accepted by ``as_sparse_variable``; the resulting variable
        must be in ``"csr"`` or ``"csc"`` format.
    index
        A sequence of exactly two indices. Each one is either a Python
        integer (wrapped into a Theano constant) or an object whose
        ``ndim`` is 0.

    Returns
    -------
    A ``gof.Apply`` node whose inputs are ``x`` followed by the two
    indices, and whose single output is a tensor scalar with the dtype
    of ``x``.

    Raises
    ------
    Exception
        If one of the indices is a ``slice``.
    NotImplementedError
        If a non-integer index has ``ndim`` different from 0.
    """
    x = as_sparse_variable(x)
    assert x.format in ["csr", "csc"]
    assert len(index) == 2

    input_op = [x]
    for ind in index:
        if isinstance(ind, slice):
            raise Exception("GetItemScalar called with a slice as index!")
        if isinstance(ind, integer_types):
            # Plain Python integer: wrap it so the graph only sees variables.
            ind = theano.tensor.constant(ind)
        elif ind.ndim != 0:
            # The original code did `raise NotImplemented()`, which fails with
            # a TypeError because NotImplemented is a non-callable sentinel,
            # not an exception class.
            raise NotImplementedError(
                "GetItemScalar only supports integer or 0-d variable indices"
            )
        input_op.append(ind)

    return gof.Apply(self, input_op, [tensor.scalar(dtype=x.dtype)])
评论列表
文章目录