compute_error.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:AND4NMF 作者: PrincetonML 项目源码 文件源码
def compute_error(A_in, Ag_in):
    """Return the summed best-match squared projection residual of Ag vs. A.

    Both inputs are first column-normalized to unit Euclidean length. Then,
    for each of the first ``D`` columns ``g`` of ``Ag_in`` (``D`` is the
    number of columns of ``A_in``), every column ``a_j`` of ``A_in`` is
    scaled by ``z_j = <g, a_j> / ||a_j||^2``, i.e. ``g`` is projected onto
    ``a_j``. The smallest squared residual ``min_j ||z_j a_j - g||^2`` is
    then added to the total.

    Args:
        A_in: 2-D array-like (``np.matrix`` or ``ndarray``) with ``D``
            columns; the estimated matrix.
        Ag_in: 2-D array-like with the same number of rows as ``A_in`` and
            at least ``D`` columns; the reference matrix.

    Returns:
        The accumulated squared residual (a NumPy float when ``D >= 1``).
    """
    # Use plain 2-D ndarrays with explicit ``@`` / broadcasting so the math
    # is the same whether callers pass np.matrix (where ``*`` is matmul) or
    # ndarray (where ``*`` is element-wise).
    A = np.asarray(A_in)
    Ag = np.asarray(Ag_in)
    D = A.shape[1]

    # Normalize every column of both matrices to unit length.
    A = A / np.linalg.norm(A, axis=0)
    Ag = Ag / np.linalg.norm(Ag, axis=0)

    # Squared column norms of A; loop-invariant, so computed once.
    sq_norm_A = np.sum(A * A, axis=0)

    error = 0
    for i in range(D):
        g = Ag[:, i]
        # Projection coefficient of g onto each column of A.
        z = (g @ A) / sq_norm_A
        # Residual norm of each projected column z_j * a_j against g.
        diff = np.linalg.norm(A * z - g[:, np.newaxis], axis=0)
        error = error + np.amin(diff) ** 2

    return error
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号