import tensorflow as tf


def kl_normal(mu0, var0,
              mu1=0.0, var1=1.0):
    """KL divergence between two normal distributions with diagonal covariance.

    This is a simplified version: instead of a full covariance matrix Σ, each
    `var` argument is a vector holding the elements of Σ's main diagonal
    (diag(Σ)). The closed form then reduces to an element-wise sum:

        KL = 0.5 * sum_i [ var0_i/var1_i + (mu0_i - mu1_i)^2 / var1_i
                           - 1 - log(var0_i / var1_i) ]

    :param mu0: μ0, shape (batch, dim).
    :param var0: diag(Σ0), shape (batch, dim).
    :param mu1: μ1 (defaults to the standard normal prior).
    :param var1: diag(Σ1) (defaults to the standard normal prior).
    :return: The KL divergence, one value per batch element.
    """
    # Small epsilon keeps the variances strictly positive inside the log.
    e = 1e-4
    var0 += e
    # Fast path: KL against the standard normal N(0, I). This check assumes
    # mu1 and var1 are Python floats (as in the defaults), not tensors.
    if mu1 == 0.0 and var1 == 1.0:
        kl = var0 + mu0 ** 2 - 1 - tf.math.log(var0)
    else:
        var1 += e
        kl = var0 / var1 + (mu0 - mu1) ** 2 / var1 - 1 - tf.math.log(var0 / var1)
    # Sum over the latent dimensions; result has shape (batch,).
    kl = 0.5 * tf.reduce_sum(kl, 1)
    return kl
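

A minimal usage sketch (the shapes and values below are invented for illustration; assumes TensorFlow 2.x with eager execution):

# Batch of 2 diagonal Gaussians over a 3-dimensional latent space.
mu0 = tf.constant([[0.0, 0.5, -0.5],
                   [1.0, 0.0, 2.0]])
var0 = tf.constant([[1.0, 0.5, 2.0],
                    [0.2, 1.0, 1.0]])

# KL against the standard normal prior N(0, I), the common VAE case:
# returns a tensor of shape (2,), one KL value per batch row.
print(kl_normal(mu0, var0))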