hermitian.py source code

python

Project: factorix    Author: gbouchar
import numpy as np


def hermitian_dot(u, v):
    """
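    Hermitian dot products between multiple complex embeddings.

    Each embedding of dimension 2k stores k real parts followed by k imaginary parts.
    The product is conjugate-linear in the first argument: <u_i, v_j> = sum over the
    k complex coordinates of conj(u_i) * v_j.

    :param u: n x 2k matrix of n embeddings, one per row
    :param v: 2k x m matrix of m embeddings, one per column (e.g. embeddings.T)
    :return: a pair of n x m matrices of Hermitian inner products between all vector combinations:
        - Re(<u_i, v_j>) for the first output
        - Im(<u_i, v_j>) for the second output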
    >>> embeddings = np.array([[1., 1, 0, 3], [0, 1, 0, 1], [-1, 1, 1, 5]])
    >>> print(hermitian_dot(embeddings, embeddings.T))
    (array([[ 11.,   4.,  15.],
           [  4.,   2.,   6.],
           [ 15.,   6.,  28.]]), array([[ 0., -2.,  3.],
           [ 2.,  0.,  4.],
           [-3., -4.,  0.]]))
    """
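    rk = u.shape[1] // 2  # number of complex coordinates per embedding
    # u holds embeddings as rows: split its columns into real and imaginary halves
    u_re = u[:, :rk]
    u_im = u[:, rk:]
    # v holds embeddings as columns: split its rows into real and imaginary halves
    v_re = v[:rk, :]
    v_im = v[rk:, :]
    # conj(a + ib) * (c + id) = (ac + bd) + i(ad - bc)
    real = np.dot(u_re, v_re) + np.dot(u_im, v_im)
    imag = np.dot(u_re, v_im) - np.dot(u_im, v_re)
    return real, imag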
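As a cross-check, the two outputs of hermitian_dot can be reproduced with NumPy's native complex arithmetic. The sketch below assumes the same layout as the code above (each 2k-dimensional embedding holds k real parts followed by k imaginary parts) and the same convention read off its return statement, with the conjugate on the first argument. The helpers `to_complex` and `hermitian_dot_complex` are hypothetical illustrations, not part of factorix.

```python
import numpy as np


def to_complex(x):
    """Pack a real (n, 2k) matrix into a complex (n, k) matrix.

    Assumes the first k columns are real parts and the last k are imaginary parts.
    """
    k = x.shape[1] // 2
    return x[:, :k] + 1j * x[:, k:]


def hermitian_dot_complex(u, v_rows):
    """Hypothetical reference: (Re, Im) of sum_k conj(u_ik) * v_jk for all pairs (i, j).

    Unlike hermitian_dot, both arguments hold embeddings as rows.
    """
    z = np.conj(to_complex(u)) @ to_complex(v_rows).T
    return z.real, z.imag


# Same embeddings as in the docstring example above.
embeddings = np.array([[1., 1, 0, 3], [0, 1, 0, 1], [-1, 1, 1, 5]])
re, im = hermitian_dot_complex(embeddings, embeddings)
print(re)
print(im)
```

Run on the docstring's embeddings, this should yield the same two matrices as the doctest: a symmetric real part and an antisymmetric imaginary part, as expected when a Hermitian product is taken of a set of vectors with itself. Passing both arguments as row matrices is a choice of this sketch only; hermitian_dot itself expects its second argument already transposed.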