def update_kl_loss(p, lambdas, T, Cs):
    """
    Update C according to the KL loss kernel, given the S couplings T_s
    calculated at the current iteration of the Gromov-Wasserstein
    barycenter algorithm.

    Implements the closed-form update for the KL loss
    L(a, b) = a * log(a / b) - a + b  (Peyre, Cuturi & Solomon, 2016):

        C = exp( sum_s lambdas[s] * T_s^T log(C_s) T_s / (p p^T) )

    where log, exp and the division by p p^T are element-wise.

    Parameters
    ----------
    p : ndarray, shape (N,)
        Weights in the targeted barycenter.
    lambdas : list of S floats
        Weight of each of the S input spaces.
    T : list of S ndarray, shape (ns, N)
        The S couplings T_s calculated at each iteration.
    Cs : list of S ndarray, shape (ns, ns)
        Metric cost matrices of the input spaces.

    Returns
    -------
    C : ndarray, shape (N, N)
        Updated C matrix.
    """
    # Cost matrices may contain zeros (e.g. a zero diagonal); log(0) = -inf
    # would turn 0 * -inf into NaN inside the matrix products, so clamp
    # before taking the logarithm.
    eps = 1e-15
    tmpsum = sum(
        lambdas[s] * np.dot(np.dot(T_s.T, np.log(np.maximum(Cs[s], eps))), T_s)
        for s, T_s in enumerate(T)
    )
    ppt = np.outer(p, p)
    return np.exp(np.divide(tmpsum, ppt))
评论列表
文章目录