x_analytical_values.py source code


Project: adversarial-variational-bayes    Author: gdikov
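For two $d$-dimensional normal distributions $\mathcal{N}(m_1, C_1)$ and $\mathcal{N}(m_2, C_2)$, the function below evaluates the standard closed-form expression, with `c1`, `c2`, `m1`, `m2` and `dim` in the code playing the roles of $C_1$, $C_2$, $m_1$, $m_2$ and $d$:

$$
D_{\mathrm{KL}}\big(\mathcal{N}(m_1, C_1)\,\|\,\mathcal{N}(m_2, C_2)\big)
= \frac{1}{2}\left[\log\frac{\det C_2}{\det C_1}
+ \operatorname{tr}\!\left(C_2^{-1} C_1\right)
+ (m_1 - m_2)^\top C_2^{-1} (m_1 - m_2) - d\right]
$$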
```python
from numpy import dot, log, trace
from numpy.linalg import det, inv


def analytical_value_d_kullback_leibler(distr1, distr2, par1, par2):
    """ Analytical value of the KL divergence for the given distributions.

    Parameters
    ----------
    distr1, distr2 : str
                     Names of the distributions.
    par1, par2 : dict
                 Parameters of the distributions. If distr1 = distr2 =
                 'normal': par1["mean"], par1["cov"] and par2["mean"],
                 par2["cov"] are the means (1-D NumPy arrays) and the
                 covariance matrices (2-D NumPy arrays).

    Returns
    -------
    d : float
        Analytical value of the Kullback-Leibler divergence
        D(distr1 || distr2).

    """

    if distr1 == 'normal' and distr2 == 'normal':
        # covariance matrices, expectations:
        c1, m1 = par1['cov'], par1['mean']
        c2, m2 = par2['cov'], par2['mean']
        dim = len(m1)

        invc2 = inv(c2)
        diffm = m1 - m2

        # 0.5 instead of 1/2 so the result is also correct under
        # Python 2 integer division, where 1/2 == 0.
        d = 0.5 * (log(det(c2) / det(c1)) + trace(dot(invc2, c1)) +
                   dot(diffm, dot(invc2, diffm)) - dim)
    else:
        raise ValueError('Unsupported distribution pair: {} and {}'
                         .format(distr1, distr2))

    return d
```
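As a sanity check on the closed form, the sketch below specializes it to one dimension ($C_1 = s_1^2$, $C_2 = s_2^2$, $d = 1$) and compares it with a Monte Carlo estimate of $\mathbb{E}_{x \sim p}[\log p(x) - \log q(x)]$. It assumes only the Python standard library; the helper names `kl_normal_1d`, `log_normal_pdf` and `kl_normal_1d_mc` are hypothetical and not part of the adversarial-variational-bayes project.

```python
import math
import random


def kl_normal_1d(m1, s1, m2, s2):
    """Closed-form KL(N(m1, s1^2) || N(m2, s2^2)).

    This is the d = 1 case of the matrix formula, with C1 = s1**2 and
    C2 = s2**2 (hypothetical helper, for illustration only).
    """
    return math.log(s2 / s1) + (s1 ** 2 + (m1 - m2) ** 2) / (2 * s2 ** 2) - 0.5


def log_normal_pdf(x, m, s):
    """Log density of N(m, s^2) evaluated at x."""
    return -0.5 * math.log(2 * math.pi * s ** 2) - (x - m) ** 2 / (2 * s ** 2)


def kl_normal_1d_mc(m1, s1, m2, s2, n=100_000, seed=0):
    """Monte Carlo estimate of E_{x ~ p}[log p(x) - log q(x)]."""
    rng = random.Random(seed)  # fixed seed for reproducibility
    total = 0.0
    for _ in range(n):
        x = rng.gauss(m1, s1)  # draw from p = N(m1, s1^2)
        total += log_normal_pdf(x, m1, s1) - log_normal_pdf(x, m2, s2)
    return total / n


if __name__ == "__main__":
    # The exact value here is log(2) - 0.25, roughly 0.4431.
    exact = kl_normal_1d(0.0, 1.0, 1.0, 2.0)
    approx = kl_normal_1d_mc(0.0, 1.0, 1.0, 2.0)
    print(exact, approx)
```

With 100,000 samples the estimate typically lands within about 0.01 of the closed-form value for these parameters. The closed form is exact and cheap, whereas the sampling route is what one falls back on when no analytical expression exists.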