x_analytical_values.py source code

python

Project: adversarial-variational-bayes  Author: gdikov
from numpy import dot, log, pi, trace
from numpy.linalg import det, inv


def analytical_value_c_cross_entropy(distr1, distr2, par1, par2):
    """ Analytical value of the cross-entropy for the given distributions.

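    The cross-entropy is H(p1, p2) = -E_{x ~ p1}[log p2(x)]. Only the case
    distr1 = distr2 = 'normal' is implemented; other combinations raise an
    exception.

    Parameters
    ----------
    distr1, distr2 : str
                     Names of the distributions.
    par1, par2 : dictionaries
                 Parameters of the distributions. If distr1 = distr2 =
                 'normal': par1["mean"], par1["cov"] and par2["mean"],
                 par2["cov"] are the means (1-D arrays) and the
                 covariance matrices (2-D arrays).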

    Returns
    -------
    c : float
        Analytical value of the cross-entropy.

    """

    if distr1 == 'normal' and distr2 == 'normal':
        # covariance matrices, expectations:
        c1, m1 = par1['cov'], par1['mean']
        c2, m2 = par2['cov'], par2['mean']
        dim = len(m1)

        invc2 = inv(c2)
        diffm = m1 - m2

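        # H = 1/2 [d log(2 pi) + log det(c2) + tr(c2^-1 c1)
        #          + (m1 - m2)^T c2^-1 (m1 - m2)]
        # 0.5 instead of 1/2: under Python 2 integer division, 1/2 == 0.
        c = 0.5 * (dim * log(2*pi) + log(det(c2)) + trace(dot(invc2, c1)) +
                   dot(diffm, dot(invc2, diffm)))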
    else:
        raise ValueError("Only distr1 = distr2 = 'normal' is supported.")

    return c
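As a sanity check, the closed form above can be compared against a Monte Carlo estimate of -E_{x ~ p1}[log p2(x)]. The sketch below is a standalone re-implementation, not code from the adversarial-variational-bayes project: `gaussian_cross_entropy` and `monte_carlo_cross_entropy` are hypothetical names. It also swaps `det`/`inv` for `numpy.linalg.slogdet`/`numpy.linalg.solve`, a common choice that avoids overflow of `det` in high dimensions and an explicit matrix inverse. It assumes NumPy >= 1.17 (for `default_rng`) and positive-definite covariance matrices.

```python
import numpy as np


def gaussian_cross_entropy(m1, c1, m2, c2):
    """H(p1, p2) = -E_{x~p1}[log p2(x)] for p1 = N(m1, c1), p2 = N(m2, c2).

    Hypothetical helper: same formula as analytical_value_c_cross_entropy,
    but using slogdet and solve instead of det and inv.
    """
    m1, m2 = np.asarray(m1, dtype=float), np.asarray(m2, dtype=float)
    c1, c2 = np.asarray(c1, dtype=float), np.asarray(c2, dtype=float)
    dim = m1.shape[0]
    diffm = m1 - m2
    sign, logdet_c2 = np.linalg.slogdet(c2)
    if sign <= 0:
        raise ValueError("c2 must be positive definite")
    # tr(c2^-1 c1) and (m1 - m2)^T c2^-1 (m1 - m2) via linear solves
    trace_term = np.trace(np.linalg.solve(c2, c1))
    quad_term = diffm @ np.linalg.solve(c2, diffm)
    return 0.5 * (dim * np.log(2 * np.pi) + logdet_c2 + trace_term + quad_term)


def monte_carlo_cross_entropy(m1, c1, m2, c2, n_samples=200_000, seed=0):
    """Estimate -E_{x~p1}[log p2(x)] by averaging over samples from p1."""
    rng = np.random.default_rng(seed)
    x = rng.multivariate_normal(m1, c1, size=n_samples)
    m2, c2 = np.asarray(m2, dtype=float), np.asarray(c2, dtype=float)
    dim = m2.shape[0]
    diff = x - m2
    # Squared Mahalanobis distance of every sample under p2
    maha = np.einsum("ij,ij->i", diff, np.linalg.solve(c2, diff.T).T)
    _, logdet_c2 = np.linalg.slogdet(c2)
    log_p2 = -0.5 * (dim * np.log(2 * np.pi) + logdet_c2 + maha)
    return -log_p2.mean()


if __name__ == "__main__":
    m1, c1 = np.array([0.0, 1.0]), np.array([[1.0, 0.3], [0.3, 2.0]])
    m2, c2 = np.array([0.5, -0.5]), np.array([[2.0, 0.0], [0.0, 1.0]])
    print("closed form:", gaussian_cross_entropy(m1, c1, m2, c2))
    print("Monte Carlo:", monte_carlo_cross_entropy(m1, c1, m2, c2))
```

With p1 = p2 the cross-entropy reduces to the differential entropy 0.5 * (d log(2 pi) + log det(c) + d), which gives a quick spot check for either implementation.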