ops.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:photinia 作者: XoriieInpottn 项目源码 文件源码
def kl_normal(mu0, var0,
              mu1=0.0, var1=1.0):
    """KL divergence for normal distribution.
    Note that this is a simple version. We don't use covariance matrix (?) here. Instead, 
    var is the vector that indicates the elements in ?'s main diagonal (diag(?)).

    :param mu0: ?0.
    :param var0: diag(?0).
    :param mu1: ?1.
    :param var1: diag(?1).
    :return: The KL divergence.
    """
    e = 1e-4
    var0 += e
    if mu1 == 0.0 and var1 == 1.0:
        kl = var0 + mu0 ** 2 - 1 - tf.log(var0)
    else:
        var1 += e
        kl = var0 / var1 + (mu0 - mu1) ** 2 / var1 - 1 - tf.log(var0 / var1)
    kl = 0.5 * tf.reduce_sum(kl, 1)
    return kl
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号