x_analytical_values.py source code

Project: adversarial-variational-bayes  Author: gdikov

```python
from numpy import absolute, dot, exp, pi
from numpy.linalg import det, inv


def analytical_value_k_prob_product(distr1, distr2, rho, par1, par2):
    """ Analytical value of the probability product kernel.

    Parameters
    ----------
    distr1, distr2 : str
                     Names of the distributions.
    rho: float, >0
         Parameter of the probability product kernel.
    par1, par2 : dicts
                 Parameters of the distributions. If distr1 = distr2 =
                 'normal': par1["mean"], par1["cov"] and par2["mean"],
                 par2["cov"] are the means and the covariance matrices.

    Returns
    -------
    k : float
         Analytical value of the probability product kernel.

    """

    if distr1 == 'normal' and distr2 == 'normal':
        # covariance matrices, expectations:
        c1, m1 = par1['cov'], par1['mean']
        c2, m2 = par2['cov'], par2['mean']
        dim = len(m1)

        # inv1, inv2, inv12:
        inv1, inv2 = inv(c1), inv(c2)
        inv12 = inv(inv1 + inv2)

        # m12 and exp_arg come from completing the square in the exponent
        # of p1(x)**rho * p2(x)**rho:
        m12 = dot(inv1, m1) + dot(inv2, m2)
        exp_arg = \
            dot(m1, dot(inv1, m1)) + dot(m2, dot(inv2, m2)) - \
            dot(m12, dot(inv12, m12))

        # Remaining Gaussian integral over x times the normalizing
        # constants of p1**rho and p2**rho:
        k = (2 * pi)**((1 - 2 * rho) * dim / 2) * rho**(-dim / 2) * \
            absolute(det(inv12))**(1 / 2) * \
            absolute(det(c1))**(-rho / 2) * \
            absolute(det(c2))**(-rho / 2) * exp(-rho / 2 * exp_arg)
    else:
        # Only the normal/normal pair is implemented.
        raise Exception('Distribution=?')

    return k
```
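
For two normal densities $p_1 = \mathcal{N}(m_1, C_1)$ and $p_2 = \mathcal{N}(m_2, C_2)$ in $d$ dimensions, the function evaluates the following expression:

$$k=\int p_1(x)^\rho\, p_2(x)^\rho\,dx=(2\pi)^{\frac{(1-2\rho)d}{2}}\rho^{-\frac d2}\big|(C_1^{-1}+C_2^{-1})^{-1}\big|^{\frac12}|C_1|^{-\frac\rho2}|C_2|^{-\frac\rho2}\exp\!\Big(-\tfrac\rho2\big[m_1^\top C_1^{-1}m_1+m_2^\top C_2^{-1}m_2-m_{12}^\top(C_1^{-1}+C_2^{-1})^{-1}m_{12}\big]\Big)$$

Here $m_{12}=C_1^{-1}m_1+C_2^{-1}m_2$.

The sketch below checks this closed form numerically. It is a condensed restatement of the normal/normal branch, using NumPy's `@` operator instead of `dot`. The helper names `prob_product_kernel_normal` and `gaussian_pdf` are hypothetical and not part of the project. The sketch tests three cases:

- With ρ = 1/2 and identical distributions, the kernel is the Bhattacharyya coefficient, which must equal 1.
- With ρ = 1, the kernel is the expected likelihood kernel, which must equal $\mathcal{N}(m_1; m_2, C_1 + C_2)$.
- For an arbitrary ρ in 1-D, the kernel should match a brute-force Riemann sum.

```python
import numpy as np
from numpy.linalg import det, inv


def prob_product_kernel_normal(m1, c1, m2, c2, rho):
    """Hypothetical condensed copy of the normal/normal branch:
    k = integral of p1(x)**rho * p2(x)**rho dx, p1 = N(m1, c1), p2 = N(m2, c2)."""
    dim = len(m1)
    inv1, inv2 = inv(c1), inv(c2)
    inv12 = inv(inv1 + inv2)
    m12 = inv1 @ m1 + inv2 @ m2
    exp_arg = m1 @ inv1 @ m1 + m2 @ inv2 @ m2 - m12 @ inv12 @ m12
    return ((2 * np.pi) ** ((1 - 2 * rho) * dim / 2) * rho ** (-dim / 2)
            * abs(det(inv12)) ** 0.5
            * abs(det(c1)) ** (-rho / 2) * abs(det(c2)) ** (-rho / 2)
            * np.exp(-rho / 2 * exp_arg))


def gaussian_pdf(x, m, c):
    """Hypothetical helper: density of N(m, c) at each row of x (shape (n, dim))."""
    diff = x - m
    quad = np.einsum('ni,ij,nj->n', diff, inv(c), diff)
    return np.exp(-quad / 2) / np.sqrt((2 * np.pi) ** len(m) * det(c))


m1, c1 = np.array([0.0, 1.0]), np.array([[2.0, 0.3], [0.3, 1.0]])
m2, c2 = np.array([1.0, -0.5]), np.array([[1.0, 0.0], [0.0, 0.5]])

# Check 1: rho = 1/2 with p1 == p2 is the Bhattacharyya coefficient of a
# density with itself, which is exactly 1.
k_self = prob_product_kernel_normal(m1, c1, m1, c1, 0.5)

# Check 2: rho = 1 (expected likelihood kernel) equals N(m1; m2, c1 + c2).
k_el = prob_product_kernel_normal(m1, c1, m2, c2, 1.0)
k_el_ref = gaussian_pdf(m1[None, :], m2, c1 + c2)[0]

# Check 3: in 1-D, compare with a brute-force Riemann sum for rho = 0.7.
rho = 0.7
a_m, a_c = np.array([0.3]), np.array([[1.5]])
b_m, b_c = np.array([-0.4]), np.array([[0.8]])
x = np.linspace(-20.0, 20.0, 200001)
dx = x[1] - x[0]
integrand = (gaussian_pdf(x[:, None], a_m, a_c) ** rho
             * gaussian_pdf(x[:, None], b_m, b_c) ** rho)
k_1d = prob_product_kernel_normal(a_m, a_c, b_m, b_c, rho)
k_1d_num = integrand.sum() * dx

print(k_self, k_el, k_el_ref, k_1d, k_1d_num)
```

The 1-D grid spans many standard deviations of the integrand, so the simple sum is already very accurate. In higher dimensions the two special cases are the cheaper checks.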