def KLdiv(mu0, Lcov0, mu1, Lcov1):
    """Sum of Gaussian KL divergences over a batch, computed with NumPy/SciPy.

    For each row i, compares N(mu0[i], S0_i) against N(mu1[i], S1_i), where
    ``S_i = L_i @ L_i.T`` is rebuilt from the given Cholesky factor. The
    returned value is::

        0.5 * sum_i [ tr(S1_i^-1 S0_i) + dm_i^T S1_i^-1 dm_i
                      + (logdet(L1_i) - logdet(L0_i)) - n ]

    with ``dm_i = mu1[i] - mu0[i]``.

    Parameters
    ----------
    mu0 : ndarray, shape (D, n)
        Means of the first set of Gaussians. Its shape supplies D (number of
        pairs) and n (dimension).
    Lcov0 : sequence of D arrays, each (n, n)
        Lower-triangular Cholesky factors of the first covariances.
    mu1 : array-like, D rows of length n
        Means of the second set of Gaussians.
    Lcov1 : sequence of D arrays, each (n, n)
        Lower-triangular Cholesky factors of the second covariances.
        ``cho_solve`` is called with ``lower=True`` on these.

    Returns
    -------
    scalar
        The summed KL divergence described above.

    Raises
    ------
    ValueError
        If ``mu1``, ``Lcov0`` or ``Lcov1`` does not have exactly ``D``
        entries (``D = mu0.shape[0]``).
    """
    D, n = mu0.shape
    # zip() below silently stops at the shortest input. A short argument would
    # drop terms from the sums while "- D * n" still counts every row of mu0,
    # giving a wrong result with no error, so reject mismatches up front.
    for arg_name, arg in (("mu1", mu1), ("Lcov0", Lcov0), ("Lcov1", Lcov1)):
        if len(arg) != D:
            raise ValueError(
                "%s has %d entries, expected %d to match mu0"
                % (arg_name, len(arg), D))
    tr, dist, ldet = 0., 0., 0.
    for m0, m1, L0, L1 in zip(mu0, mu1, Lcov0, Lcov1):
        # tr(S1^-1 S0): rebuild S0 = L0 L0^T and solve against S1 using its
        # (lower) Cholesky factor instead of forming an explicit inverse.
        tr += np.trace(cho_solve((L1, True), L0.dot(L0.T)))
        md = m1 - m0
        # Mahalanobis term md^T S1^-1 md.
        dist += md.dot(cho_solve((L1, True), md))
        # NOTE(review): the KL formula needs log det(S1) - log det(S0); this
        # assumes logdet(L) returns log det(L L^T) from a Cholesky factor
        # (i.e. 2 * sum(log(diag(L)))). logdet is defined elsewhere; confirm.
        ldet += logdet(L1) - logdet(L0)
    KL = 0.5 * (tr + dist + ldet - D * n)
    return KL
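# NOTE(review): leftover web-page navigation text, not code:
# "Comment list" / "Table of contents" (originally in Chinese).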