def kron(a, b):
    """Computes the Kronecker product of two arrays.

    If either operand is zero-dimensional the result is simply their
    elementwise product. Otherwise the shape of the lower-rank operand is
    left-padded with ones so both shapes have equal length before the
    product is assembled.

    Args:
        a (~cupy.ndarray): The first argument.
        b (~cupy.ndarray): The second argument.

    Returns:
        ~cupy.ndarray: Output array.

    .. seealso:: :func:`numpy.kron`

    """
    rank_a, rank_b = a.ndim, b.ndim
    if rank_a == 0 or rank_b == 0:
        # Scalar operand: delegate to ordinary multiplication.
        return cupy.multiply(a, b)

    shape_a, shape_b = a.shape, b.shape
    rank = max(rank_a, rank_b)
    # Only the operand with fewer dimensions gets leading unit axes.
    if rank_a < rank:
        shape_a = (1,) * (rank - rank_a) + shape_a
    elif rank_b < rank:
        shape_b = (1,) * (rank - rank_b) + shape_b

    # NOTE(review): tensordot_core receives a.size, b.size and 1 plus the
    # combined shape; presumably this produces the outer product of the two
    # operands laid out as shape_a + shape_b -- confirm against core.
    result = core.tensordot_core(
        a, b, None, a.size, b.size, 1, shape_a + shape_b)

    # One concatenation pass per dimension, always along the same axis.
    # NOTE(review): presumably each pass merges one axis of `a` with the
    # matching axis of `b` -- confirm against core.concatenate_method.
    last_axis = rank - 1
    for _ in six.moves.range(rank):
        result = core.concatenate_method(result, axis=last_axis)

    return result
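# Page navigation labels that appear to have been captured with the code
# (not part of the module): "Comment list" and "Table of contents".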