product.py source file

python

Project: cupy  Author: cupy
# Imports inferred from the names used below (six, cupy, core); the
# snippet itself does not show the module's import block.
import six

import cupy
from cupy import core


def kron(a, b):
    """Returns the kronecker product of two arrays.

    Args:
        a (~cupy.ndarray): The first argument.
        b (~cupy.ndarray): The second argument.

    Returns:
        ~cupy.ndarray: Output array.

    .. seealso:: :func:`numpy.kron`

    """
    a_ndim = a.ndim
    b_ndim = b.ndim
    if a_ndim == 0 or b_ndim == 0:
        return cupy.multiply(a, b)

    ndim = b_ndim
    a_shape = a.shape
    b_shape = b.shape
    if a_ndim != b_ndim:
        if b_ndim > a_ndim:
            a_shape = (1,) * (b_ndim - a_ndim) + a_shape
        else:
            b_shape = (1,) * (a_ndim - b_ndim) + b_shape
            ndim = a_ndim

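    axis = ndim - 1
    # Outer product of the flattened inputs: an (a.size, 1) by (1, b.size)
    # matrix product whose result is shaped as a_shape + b_shape.
    out = core.tensordot_core(a, b, None, a.size, b.size, 1, a_shape + b_shape)
    # Each pass concatenates the slices along the leading axis, folding one
    # axis of `a` into the matching axis of `b`.
    for _ in six.moves.range(ndim):
        out = core.concatenate_method(out, axis=axis)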

    return out
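
The same outer-product-then-concatenate idea can be tried on the CPU with plain NumPy. The following is only a minimal sketch under stated assumptions: `kron_sketch` is a hypothetical name, `np.outer` followed by `reshape` stands in for `core.tensordot_core`, and `np.concatenate` stands in for `core.concatenate_method`. It is not CuPy's implementation.

```python
import numpy as np


def kron_sketch(a, b):
    """Hypothetical NumPy re-creation of the steps used in kron above."""
    a = np.asarray(a)
    b = np.asarray(b)
    # Scalars: the Kronecker product reduces to element-wise multiplication.
    if a.ndim == 0 or b.ndim == 0:
        return np.multiply(a, b)

    # Left-pad the shape of the lower-dimensional input with 1s.
    ndim = max(a.ndim, b.ndim)
    a_shape = (1,) * (ndim - a.ndim) + a.shape
    b_shape = (1,) * (ndim - b.ndim) + b.shape

    # Outer product of the flattened inputs, viewed as a_shape + b_shape.
    # (Assumption: this plays the role of core.tensordot_core with k=1.)
    out = np.outer(a, b).reshape(a_shape + b_shape)

    # Each pass iterates over the leading axis and joins the slices along
    # axis ndim - 1, merging one axis of `a` into the matching axis of `b`.
    axis = ndim - 1
    for _ in range(ndim):
        out = np.concatenate(out, axis=axis)
    return out


a = np.array([[1, 2], [3, 4]])
b = np.array([[0, 1], [1, 0]])
# The sketch should agree with NumPy's reference implementation.
assert np.array_equal(kron_sketch(a, b), np.kron(a, b))
```

After the loop, axis `i` of the result has length `a_shape[i] * b_shape[i]`, so a `(2, 3)` input combined with a `(4,)` input gives a `(2, 12)` array.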