metrics.py source code

python

Project: bnn-analysis    Author: myshkov
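import numpy as np
import scipy.stats as sc

# MAX_GRID_POINTS, MAX_KL and _cached_p_pdf are defined outside this function
# and are not shown in the snippet. The values below are hypothetical
# placeholders so the function can run on its own.
MAX_GRID_POINTS = 1000  # hypothetical cap on the number of grid points
MAX_KL = 10.0           # hypothetical upper bound on the returned KL value
_cached_p_pdf = {}      # cache of fitted KDEs for p, keyed by a caller-supplied index


def kl_divergence(p_samples, q_samples):
    """Estimate KL(p || q) from 1-D samples using Gaussian KDEs on a shared grid.

    p_samples may be a plain array or an (idx, array) tuple; in the tuple case
    the KDE for p is cached under idx so it is fitted only once.
    """
    # estimate densities
    # p_samples = np.nan_to_num(p_samples)
    # q_samples = np.nan_to_num(q_samples)

    if isinstance(p_samples, tuple):
        idx, p_samples = p_samples

        if idx not in _cached_p_pdf:
            _cached_p_pdf[idx] = sc.gaussian_kde(p_samples)

        p_pdf = _cached_p_pdf[idx]
    else:
        p_pdf = sc.gaussian_kde(p_samples)

    q_pdf = sc.gaussian_kde(q_samples)

    # joint support: the grid spans the range of both sample sets
    left = min(np.min(p_samples), np.min(q_samples))
    right = max(np.max(p_samples), np.max(q_samples))

    p_samples_num = p_samples.shape[0]
    q_samples_num = q_samples.shape[0]

    # quantise: evaluate both densities on an evenly spaced grid whose size
    # follows the larger sample count, capped at MAX_GRID_POINTS
    lin = np.linspace(left, right, min(max(p_samples_num, q_samples_num), MAX_GRID_POINTS))
    p = p_pdf.pdf(lin)
    q = q_pdf.pdf(lin)

    # KL: scipy.stats.entropy(p, q) normalises p and q, then returns
    # sum(p * log(p / q)); the result is clipped to MAX_KL
    kl = min(sc.entropy(p, q), MAX_KL)

    return kl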
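As a rough sanity check of this approach, the sketch below re-implements the same steps (a Gaussian KDE for each sample set, a shared evenly spaced grid over the joint range, then `scipy.stats.entropy`) as a standalone function and compares it with the closed-form KL divergence between two normal distributions. The names `grid_kde_kl` and `gaussian_kl` and the default `max_grid_points=1000` are hypothetical; the `MAX_KL` cap and the KDE cache are left out for brevity. The last few lines confirm that `scipy.stats.entropy` normalises its inputs, which is why raw density values on the grid can be passed to it directly.

```python
import numpy as np
import scipy.stats as sc


def grid_kde_kl(p_samples, q_samples, max_grid_points=1000):
    """Grid-based KDE estimate of KL(p || q), following the steps of kl_divergence.

    max_grid_points is a hypothetical stand-in for MAX_GRID_POINTS; the MAX_KL
    clipping and the KDE cache are omitted.
    """
    p_pdf = sc.gaussian_kde(p_samples)
    q_pdf = sc.gaussian_kde(q_samples)
    # shared grid over the joint range of both sample sets
    left = min(p_samples.min(), q_samples.min())
    right = max(p_samples.max(), q_samples.max())
    n = min(max(len(p_samples), len(q_samples)), max_grid_points)
    grid = np.linspace(left, right, n)
    # entropy() normalises both vectors, then sums p * log(p / q)
    return sc.entropy(p_pdf.pdf(grid), q_pdf.pdf(grid))


def gaussian_kl(mu_p, sigma_p, mu_q, sigma_q):
    """Closed-form KL(N(mu_p, sigma_p^2) || N(mu_q, sigma_q^2))."""
    return (np.log(sigma_q / sigma_p)
            + (sigma_p ** 2 + (mu_p - mu_q) ** 2) / (2 * sigma_q ** 2)
            - 0.5)


rng = np.random.default_rng(0)
p_samples = rng.normal(0.0, 1.0, size=5000)
q_samples = rng.normal(1.0, 1.5, size=5000)

estimate = grid_kde_kl(p_samples, q_samples)
exact = gaussian_kl(0.0, 1.0, 1.0, 1.5)
print(f"grid KDE estimate: {estimate:.3f}, closed form: {exact:.3f}")

# entropy(p, q) on unnormalised vectors matches the manual KL sum after
# normalising both vectors to sum to 1
p = np.array([1.0, 2.0, 3.0])
q = np.array([2.0, 2.0, 2.0])
p_n, q_n = p / p.sum(), q / q.sum()
manual = np.sum(p_n * np.log(p_n / q_n))
assert np.isclose(sc.entropy(p, q), manual)
```

The estimate should land close to, but not exactly on, the closed-form value: the KDE bandwidth slightly widens both densities, and the grid truncates the tails at the sample extremes.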