def make_node(self, alpha, x, y, z):
    """Build the Apply node for this op from ``alpha``, ``x``, ``y``, ``z``.

    At least one of ``x`` and ``y`` must be a sparse variable (as decided
    by ``_is_sparse_variable``); the other one may be dense, in which case
    it is converted with ``tensor.as_tensor_variable``.  ``alpha`` and
    ``z`` are always converted to tensor variables.

    :param alpha: scalar-like value; after conversion every dimension
        must be broadcastable.
    :param x: left operand, sparse or dense, 2-d when dense.
    :param y: right operand, sparse or dense, 2-d when dense.
    :param z: dense operand, must be 2-d after conversion.

    :return: a ``gof.Apply`` with inputs ``[alpha, x, y, z]`` and a single
        2-d, non-broadcastable tensor output whose dtype is the upcast of
        the four input dtypes.

    :raises TypeError: if neither ``x`` nor ``y`` is sparse.
    :raises AssertionError: if a dimensionality, broadcastability or
        sparse-format check fails.
    """
    if not _is_sparse_variable(x) and not _is_sparse_variable(y):
        # If x and y are tensor, we don't want to use this class
        # We should use Dot22 and Gemm in that case.
        raise TypeError(x)

    # Convert before touching ``.type``: a raw Python/NumPy ``alpha`` or
    # ``z`` has no ``.type`` attribute until it is wrapped as a variable.
    alpha = tensor.as_tensor_variable(alpha)
    z = tensor.as_tensor_variable(z)
    assert z.ndim == 2
    assert alpha.type.broadcastable == (True,) * alpha.ndim

    if not _is_sparse_variable(x):
        # x is dense, so y is the sparse operand (guaranteed by the guard
        # above) and must be in a supported format.
        x = tensor.as_tensor_variable(x)
        assert y.format in ["csr", "csc"]
        assert x.ndim == 2
    if not _is_sparse_variable(y):
        # y is dense, so x is the sparse operand.
        y = tensor.as_tensor_variable(y)
        assert x.format in ["csr", "csc"]
        assert y.ndim == 2

    # Computed last so that every operand is a variable exposing
    # ``.type.dtype``.
    dtype_out = scalar.upcast(alpha.type.dtype, x.type.dtype,
                              y.type.dtype, z.type.dtype)

    return gof.Apply(self, [alpha, x, y, z],
                     [tensor.tensor(dtype=dtype_out,
                                    broadcastable=(False, False))])
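# Web-page navigation residue (not part of the code): "Comment list"
# Web-page navigation residue (not part of the code): "Article table of contents"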