def structured_dot(x, y):
    """
    Structured Dot is like dot, except that only the gradient with respect
    to the non-zero elements of the sparse operand is calculated and
    propagated.

    The output is presumed to be a dense matrix, and is represented by a
    TensorType instance.

    Parameters
    ----------
    x
        A sparse or dense matrix. A scipy sparse matrix (any object with a
        ``getnnz`` attribute) is first wrapped with ``as_sparse_variable``
        and must then be in "csr" or "csc" format.
    y
        A sparse or dense matrix, handled the same way as `x`.

    Returns
    -------
    A dense matrix
        The dot product of `x` and `y`.

    Raises
    ------
    TypeError
        If neither `x` nor `y` is a sparse variable.
    ValueError
        If a scipy sparse argument is not in "csr" or "csc" format.

    Notes
    -----
    The grad implemented is structured. At least one of `x` and `y` must be
    sparse. When only `y` is sparse, the product is computed as
    ``(y.T . x.T).T`` so that the sparse operand is always the first
    argument of the underlying structured-dot op.
    """
    # @todo: Maybe the triple-transposition formulation (when x is dense)
    # is slow. See if there is a direct way to do this.
    # (JB 20090528: Transposing tensors and sparse matrices is constant-time,
    # inplace, and fast.)
    if hasattr(x, 'getnnz'):
        x = as_sparse_variable(x)
        # Explicit check instead of `assert`, which is stripped under -O.
        if x.format not in ("csr", "csc"):
            raise ValueError(
                "structured_dot only supports csr and csc sparse formats, "
                "got %r" % (x.format,))
    if hasattr(y, 'getnnz'):
        y = as_sparse_variable(y)
        if y.format not in ("csr", "csc"):
            raise ValueError(
                "structured_dot only supports csr and csc sparse formats, "
                "got %r" % (y.format,))

    x_is_sparse_variable = _is_sparse_variable(x)
    y_is_sparse_variable = _is_sparse_variable(y)
    if not x_is_sparse_variable and not y_is_sparse_variable:
        raise TypeError('structured_dot requires at least one sparse argument')

    if x_is_sparse_variable:
        return _structured_dot(x, y)
    # Only y is sparse here (guaranteed by the check above): move it into the
    # sparse (first) slot via transposition, then transpose the result back.
    return _structured_dot(y.T, x.T).T
评论列表
文章目录