def analytical_value_d_kullback_leibler(distr1, distr2, par1, par2):
    """ Analytical value of the KL divergence for the given distributions.

    Parameters
    ----------
    distr1, distr2 : str-s
        Names of the distributions. Only distr1 = distr2 = 'normal' is
        supported.
    par1, par2 : dictionary-s
        Parameters of the distributions. If distr1 = distr2 =
        'normal': par1["mean"], par1["cov"] and par2["mean"],
        par2["cov"] are the means (1-D array-likes of length dim) and the
        covariance matrices ((dim, dim) array-likes, positive definite).

    Returns
    -------
    d : float
        Analytical value of the Kullback-Leibler divergence
        D(distr1 || distr2).

    Raises
    ------
    ValueError
        If the distribution pair is not supported, or if a covariance
        matrix does not have a positive determinant.

    Notes
    -----
    For N(m1, c1) and N(m2, c2) in dim dimensions:
        D = 1/2 * (log(det(c2)/det(c1)) + tr(c2^{-1} c1)
                   + (m1-m2)^T c2^{-1} (m1-m2) - dim)
    """
    # Local import: keeps this block self-contained regardless of which
    # numpy names the module imports at top level.
    import numpy as np

    if distr1 == 'normal' and distr2 == 'normal':
        # covariance matrices, expectations (lists are accepted as well):
        c1 = np.asarray(par1['cov'], dtype=float)
        m1 = np.asarray(par1['mean'], dtype=float)
        c2 = np.asarray(par2['cov'], dtype=float)
        m2 = np.asarray(par2['mean'], dtype=float)
        dim = len(m1)

        # log(det(c2)/det(c1)) computed in log space: det() itself
        # underflows/overflows in high dimension (e.g. det(0.01*I_400) == 0),
        # which would turn the ratio into inf or nan.
        sign1, logdet1 = np.linalg.slogdet(c1)
        sign2, logdet2 = np.linalg.slogdet(c2)
        if sign1 <= 0 or sign2 <= 0:
            raise ValueError('Covariance matrices must be positive definite.')

        diffm = m1 - m2
        # c2^{-1} c1 and c2^{-1} diffm via linear solves instead of an
        # explicit inverse (cheaper and numerically more accurate).
        trace_term = np.trace(np.linalg.solve(c2, c1))
        quad_term = np.dot(diffm, np.linalg.solve(c2, diffm))

        # 0.5 rather than 1/2: under Python 2 without true division, 1/2 == 0.
        d = 0.5 * (logdet2 - logdet1 + trace_term + quad_term - dim)
    else:
        raise ValueError('Distribution=? (unsupported pair: {!r}, {!r})'
                         .format(distr1, distr2))
    return float(d)
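# Source: x_analytical_values.py (file source code, Python).
# Residue of the web page this listing was copied from, translated from Chinese:
# views: 26 | bookmarks: 0 | likes: 0 | comments: 0 | comment list | table of contents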