util.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:dgm 作者: ashwindcruz 项目源码 文件源码
def gaussian_kl_divergence_standard(mu, ln_var):
   """D_{KL}(N(mu,var) | N(0,1))"""
   batch_size = float(mu.data.shape[0])
   S = F.exp(ln_var)
   D = mu.data.size

   KL_sum = 0.5*(F.sum(S, axis=1) + F.sum(mu*mu, axis=1) - F.sum(ln_var, axis=1) - D/batch_size)

   return KL_sum #/ batchsize
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号