def kl_div(mu1, lv1, lv2):
    """KL divergence between N(mu1, exp(lv1)^2) and a zero-mean prior N(0, exp(lv2)^2).

    Closed form (with s1 = exp(lv1), s2 = exp(lv2), prior mean fixed at 0):
        ln s2 - ln s1 + (s1^2 + mu1^2) / (2 * s2^2) - 1/2

    NOTE(review): lv1/lv2 are treated as log standard deviations (not
    log variances) — exp(lv)**2 is used as the variance; confirm with callers.
    """
    # Give 2-D inputs a leading batch axis of size 1 so the broadcast below works.
    if len(lv1.shape) == 2:
        mu1 = F.expand_dims(mu1, 0)
        lv1 = F.expand_dims(lv1, 0)
    # Prior log-std is expanded to match the posterior's (batched) shape.
    lv2 = F.broadcast_to(lv2, lv1.shape)
    var1 = F.exp(lv1) ** 2.0
    var2 = F.exp(lv2) ** 2.0
    return lv2 - lv1 + .5 * var1 / var2 + .5 * mu1 ** 2. / var2 - .5