W_means_class.py source code

python

Project: Multilevel-Wasserstein-Means    Author: moonfolk
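```python
import numpy as np
from scipy.spatial.distance import cdist

# algo3 is defined elsewhere in W_means_class.py (not shown on this page);
# its first return value is used below as the transport plan.


def update_atoms(S, b, H, a, Z, W, labels, weight=(1., 1.), max_iter=50):
    """Update the atoms S (shape d x k, one column per atom) in place.

    Each atom becomes the transport-weighted mean of the points in Z[m] and
    in the assigned supports H[labels[m]], scaled by weight[0] and weight[1].
    """
    M = len(Z)
    # Euclidean cost matrices between support points and the current atoms
    M_h = [cdist(H[labels[m]].T, S.T, metric='euclidean') for m in range(M)]
    M_z = [cdist(Z[m].T, S.T, metric='euclidean') for m in range(M)]
    # Transport plans (rows: support points, columns: atoms)
    T_h = [algo3(a[labels[m]], b[m], M_h[m], max_iter=max_iter)[0] for m in range(M)]
    T_z = [algo3(W[m], b[m], M_z[m], max_iter=max_iter)[0] for m in range(M)]
    H_assigned = [H[labels[m]] for m in range(M)]
    k = S.shape[1]
    for l in range(k):
        # Plan-weighted sum of the points sent to atom l, and the mass they carry
        z_part = weight[0] * np.sum([(z * t[:, l]).sum(axis=1) for z, t in zip(Z, T_z)], axis=0)
        z_weight = weight[0] * np.sum([t[:, l].sum() for t in T_z])
        h_part = weight[1] * np.sum([(h * th[:, l]).sum(axis=1) for h, th in zip(H_assigned, T_h)], axis=0)
        h_weight = weight[1] * np.sum([th[:, l].sum() for th in T_h])
        S[:, l] = (z_part + h_part) / (z_weight + h_weight)

#################### Learning functions for algorithms

## No constraint algorithm

## Initialization based on k-means
```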
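Given fixed transport plans, the loop in `update_atoms` replaces each atom with a plan-weighted mean of the points transported to it. The sketch below isolates that single step and vectorizes it, assuming the plans are already available (in the original they come from `algo3`, which this page does not show). The helper `barycentric_atom_update` and the toy data are hypothetical and for illustration only.

```python
import numpy as np


def barycentric_atom_update(supports, plans, weight=1.0):
    """Hypothetical helper: weighted sums for the atom update.

    supports: list of (d, n_m) arrays, one column per support point.
    plans:    list of (n_m, k) transport plans; plans[m][i, l] is the mass
              sent from supports[m][:, i] to atom l.
    Returns (num, den): num is (d, k), the plan-weighted sum of points per
    atom; den is (k,), the total mass per atom. Both are scaled by `weight`.
    """
    # z @ t gives, for every atom l at once, sum_i t[i, l] * z[:, i],
    # which is what the per-atom loop computes one column at a time.
    num = weight * sum(z @ t for z, t in zip(supports, plans))
    den = weight * sum(t.sum(axis=0) for t in plans)
    return num, den


# Toy example: two 2-D points, (0, 0) and (2, 0), each sending mass 0.5
# to a single atom.
Z = [np.array([[0.0, 2.0],
               [0.0, 0.0]])]
T_z = [np.array([[0.5],
                 [0.5]])]
num, den = barycentric_atom_update(Z, T_z)
S_new = num / den  # the atom moves to the midpoint (1, 0)
print(S_new)
```

With both groups, as in `update_atoms`, the new atoms would be `(num_z + num_h) / (den_z + den_h)`, computing the Z terms with `weight[0]` and the H terms with `weight[1]`. Replacing the Python loop over atoms with one matrix product per group is only a speed-oriented design choice; it should give the same result as the loop.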