prob.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:tf_practice 作者: juho-lee 项目源码 文件源码
def rect_gaussian_kld(mean, log_var, mean0=0., log_var0=0., reduce_mean=True):
    """Closed-form KL term between two (presumably rectified) Gaussians.

    The first distribution is parameterised by ``mean`` / ``log_var`` and the
    second by ``mean0`` / ``log_var0``.  The result is the sum of

    * a point-mass term comparing the probability each Gaussian puts on
      ``x <= 0`` (``Phi(-mean/std)`` versus ``Phi(-mean0/std0)``), and
    * the integral over ``x > 0`` of the first Gaussian's density times the
      log-ratio of the two Gaussian densities, expanded into constant,
      quadratic and linear pieces in ``x``.

    Args:
        mean: mean of the first Gaussian.
        log_var: log-variance of the first Gaussian.
        mean0: mean of the second Gaussian (default ``0.``).
        log_var0: log-variance of the second Gaussian (default ``0.``).
        reduce_mean: when true, average the per-row sums into one value.

    Returns:
        The element-wise KL summed over axis 1, or the mean of those sums
        when ``reduce_mean`` is true.

    NOTE(review): inputs must have at least two axes because of the
    ``reduce_sum`` over axis 1 -- presumably (batch, dim); confirm with callers.
    """
    # Small constant added inside the logs so a zero mass does not give -inf.
    eps = 1.0e-10

    def std_normal_pdf(z):
        return tf.exp(-0.5*tf.square(z))/np.sqrt(2*np.pi)

    def std_normal_cdf(z):
        return 0.5 + 0.5*tf.erf(z/np.sqrt(2))

    # Moments / scales of the first Gaussian.
    q_mean_sq = tf.square(mean)
    q_var = tf.exp(log_var)
    q_log_std = 0.5*log_var
    q_std = tf.exp(q_log_std)

    # Moments / scales of the second Gaussian.
    p_mean_sq = tf.square(mean0)
    p_var = tf.exp(log_var0)
    p_log_std = 0.5*log_var0
    p_std = tf.exp(p_log_std)

    # Standardised location of the origin under the first Gaussian.
    z_origin = -mean/q_std
    mass_nonpos = std_normal_cdf(z_origin)
    mass_pos = 1 - mass_nonpos
    density_origin = std_normal_pdf(z_origin)

    # Contribution of the probability mass sitting on x <= 0.
    point_mass_term = mass_nonpos*(
        tf.log(mass_nonpos+eps) - tf.log(std_normal_cdf(-mean0/p_std)+eps))

    # x-independent part of the log density ratio, weighted by P(x > 0).
    const_term = mass_pos*(
        p_log_std - q_log_std + 0.5*(p_mean_sq/p_var - q_mean_sq/q_var))

    # x**2 coefficient times the partial second moment over x > 0.
    quad_term = (0.5/p_var - 0.5/q_var)*(
        (q_mean_sq + q_var)*mass_pos + mean*q_std*density_origin)

    # x coefficient times the partial first moment over x > 0.
    lin_term = (mean0/p_var - mean/q_var)*(
        mean*mass_pos + q_std*density_origin)

    elementwise = point_mass_term + const_term + quad_term - lin_term
    per_row = tf.reduce_sum(elementwise, 1)
    return tf.reduce_mean(per_row) if reduce_mean else per_row
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号