def make_node(self, a, b):
    """Build the Apply node for this op from inputs ``a`` and ``b``.

    ``a`` is converted with ``as_sparse_variable`` and must be stored in
    one of the sparse formats "csr", "csc" or "bsr".  ``b`` is NOT
    converted here: it must already expose ``.type`` (with ``dtype``,
    ``ndim`` and ``broadcastable``) and must be 2-dimensional.

    The output dtype is ``scalar.upcast`` of the two input dtypes.  When
    ``b`` is sparse, the single output is a ``SparseType`` in ``a``'s
    format.  Otherwise it is a dense tensor whose broadcastable pattern
    is ``(False, b.type.broadcastable[1])``.

    Raises:
        TypeError: if ``a`` is not a sparse variable after conversion.
        NotImplementedError: if ``a`` uses an unsupported sparse format,
            or if ``b`` is not 2-dimensional.
    """
    a = as_sparse_variable(a)
    # Check the type before touching ``a.format``, so a non-sparse input
    # gets the descriptive TypeError rather than an AttributeError.
    if not _is_sparse_variable(a):
        raise TypeError('First argument must be of type SparseVariable '
                        'or SparseConstant')
    # Use an explicit raise instead of ``assert``: asserts are stripped
    # under ``python -O``, which would let unsupported formats through
    # silently.
    supported_formats = ("csr", "csc", "bsr")
    if a.format not in supported_formats:
        raise NotImplementedError(
            'Unsupported sparse format %r for first argument; expected '
            'one of %s' % (a.format, ", ".join(supported_formats)))
    dtype_out = scalar.upcast(a.type.dtype, b.type.dtype)
    if b.type.ndim != 2:
        raise NotImplementedError('non-matrix b')
    if _is_sparse_variable(b):
        out = SparseType(a.type.format, dtype_out)()
    else:
        # NOTE(review): presumably rows follow ``a`` (never broadcastable)
        # and columns follow ``b`` -- confirm against perform().
        out = tensor.tensor(dtype_out, (False, b.type.broadcastable[1]))
    return gof.Apply(self, [a, b], [out])
评论列表
文章目录