objectives.py source file

python
Project: CopyNet  Author: MultiPath
import theano.tensor as T  # assumption: `T` below is Theano's tensor module


def gaussian_kl_divergence(mean, ln_var):
    """Computes the KL-divergence of Gaussian variables from the standard one.

    Given two variables ``mean`` representing :math:`\\mu` and ``ln_var``
    representing :math:`\\log(\\sigma^2)`, this function returns a variable
    representing the KL-divergence between the given multi-dimensional Gaussian
    :math:`N(\\mu, S)` and the standard Gaussian :math:`N(0, I)`

    .. math::

       D_{\\mathbf{KL}}(N(\\mu, S) \\| N(0, I)),

    where :math:`S` is a diagonal matrix such that :math:`S_{ii} = \\sigma_i^2`
    and :math:`I` is an identity matrix.

    Args:
        mean (~chainer.Variable): A variable representing mean of given
            gaussian distribution, :math:`\\mu`.
        ln_var (~chainer.Variable): A variable representing logarithm of
            variance of given gaussian distribution, :math:`\\log(\\sigma^2)`.

    Returns:
        ~chainer.Variable: A variable representing KL-divergence between
            given gaussian distribution and the standard gaussian.

    """
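    # Recover the variance sigma^2 from its logarithm.
    var = T.exp(ln_var)
    # Closed-form KL term per dimension, summed over axis 1 (one value per row).
    return 0.5 * T.sum(mean * mean + var - ln_var - 1, axis=1)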


# aliases
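
The function above evaluates the closed form 0.5 * sum_i(mu_i^2 + sigma_i^2 - log(sigma_i^2) - 1) for each row of its inputs. As a rough sanity check, the same expression can be sketched in pure Python without Theano. The helper name `gaussian_kl_divergence_ref` and the sample inputs are hypothetical and not part of the original file; the sketch assumes inputs shaped as a list of rows (batch x dimensions).

```python
import math


def gaussian_kl_divergence_ref(mean, ln_var):
    """Pure-Python reference for KL(N(mu, diag(sigma^2)) || N(0, I)).

    Hypothetical helper for illustration only. `mean` and `ln_var` are
    lists of rows (batch x dimensions); returns one KL value per row.
    """
    return [
        0.5 * sum(m * m + math.exp(lv) - lv - 1.0
                  for m, lv in zip(mean_row, ln_var_row))
        for mean_row, ln_var_row in zip(mean, ln_var)
    ]


# Row 0 is already the standard Gaussian; row 1 has a shifted mean
# and a variance of 4 in its second dimension.
mean = [[0.0, 0.0], [1.0, -1.0]]
ln_var = [[0.0, 0.0], [0.0, math.log(4.0)]]

kl = gaussian_kl_divergence_ref(mean, ln_var)
print(kl)
```

The first row gives a divergence of zero because that distribution already equals N(0, I); any shift in the mean or change in the variance makes the value strictly positive.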