import torch


def KLDGaussian(Q, N, eps=1e-8):
    """KL divergence KL(Q || N) between two Gaussians.

    Assumes Q ~ N(mu0, A \Sigma_0 A^T) where A = I + v r^T (a rank-one
    perturbation of the diagonal covariance \Sigma_0), and
    N ~ N(mu1, \Sigma_1) with \Sigma_1 diagonal.
    """
    sum1 = lambda x: torch.sum(x, dim=1)  # sum over the feature dimension, per sample
    k = float(Q.mu.size()[1])  # dimension of the distribution
    mu0, v, r, mu1 = Q.mu, Q.v, Q.r, N.mu
    s02, s12 = Q.sigma.pow(2) + eps, N.sigma.pow(2) + eps

    # Trace term tr(Sigma_1^{-1} A Sigma_0 A^T), expanded via the rank-one
    # structure of A so that only elementwise operations are needed.
    a = sum1(s02 * (1. + 2. * v * r) / s12) + sum1(v.pow(2) / s12) * sum1(r.pow(2) * s02)
    # Difference-of-means term (mu1 - mu0)^T Sigma_1^{-1} (mu1 - mu0).
    b = sum1((mu1 - mu0).pow(2) / s12)
    # Ratio-of-determinants term ln(det Sigma_1 / det(A Sigma_0 A^T));
    # det(A) = 1 + v^T r by the matrix determinant lemma.
    c = 2. * (sum1(N.logsigma - Q.logsigma) - torch.log(1. + sum1(v * r) + eps))

    return 0.5 * (a + b - k + c)
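
For reference, the function evaluates the standard closed form for the KL divergence between two Gaussians, specialized to the rank-one-perturbed covariance described in the docstring; the determinant of A follows from the matrix determinant lemma, det(I + v r^T) = 1 + v^T r:

$$
\mathrm{KL}(Q \,\|\, N) = \frac{1}{2}\left[ \operatorname{tr}\!\left(\Sigma_1^{-1} A \Sigma_0 A^{\top}\right) + (\mu_1 - \mu_0)^{\top} \Sigma_1^{-1} (\mu_1 - \mu_0) - k + \ln \frac{\det \Sigma_1}{\left(1 + v^{\top} r\right)^2 \det \Sigma_0} \right]
$$

The tensors a, b, and c computed above are exactly the trace, difference-of-means, and log-determinant-ratio terms of this expression, evaluated per batch element.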
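A minimal smoke test, assuming Q and N are plain attribute containers (SimpleNamespace here is a stand-in for whatever distribution class the surrounding code actually uses). Setting v = r = 0 makes A = I, so Q and N become identical diagonal Gaussians and the divergence should be approximately zero:

from types import SimpleNamespace

import torch

torch.manual_seed(0)
batch, k = 4, 3
sigma = torch.rand(batch, k) + 0.5  # keep sigma well away from zero

Q = SimpleNamespace(
    mu=torch.randn(batch, k),
    sigma=sigma,
    logsigma=sigma.log(),
    v=torch.zeros(batch, k),  # v = r = 0 gives A = I + v r^T = I
    r=torch.zeros(batch, k),
)
N = SimpleNamespace(mu=Q.mu.clone(), sigma=sigma, logsigma=sigma.log())

print(KLDGaussian(Q, N))  # ~zero for each of the 4 batch elements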