def gaussian_kl_divergence_standard(mu, ln_var):
    """Compute D_{KL}(N(mu, var) | N(0, 1)) for a diagonal Gaussian.

    Uses the closed form ``0.5 * sum(var + mu**2 - ln(var) - 1)``, where the
    sum runs over axis 1 and ``var = exp(ln_var)``. The result is NOT
    averaged over the batch; one value is returned per row.

    Args:
        mu: Mean of the Gaussian. Must have at least two axes; axis 0 is
            treated as the batch axis and axis 1 is reduced.
        ln_var: Natural log of the variance, same shape as ``mu``.

    Returns:
        The KL divergence with axis 1 summed out (shape ``(batch,)`` for
        2-D inputs), as whatever array/variable type ``F.sum`` returns.

    Note:
        ``F`` is the module-level functions namespace (presumably
        ``chainer.functions`` -- confirm against the file's imports).
    """
    var = F.exp(ln_var)
    # Each axis=1 sum reduces exactly shape[1] entries, so that is the number
    # of "-1" terms to subtract. Reading it from the shape (rather than
    # total_size / batch_size) avoids a ZeroDivisionError on an empty batch
    # and stays correct when the inputs have more than two axes.
    n_reduced = float(mu.shape[1])
    return 0.5 * (
        F.sum(var, axis=1)
        + F.sum(mu * mu, axis=1)
        - F.sum(ln_var, axis=1)
        - n_reduced
    )
评论列表
文章目录