lnn.py source code

Project: lnn    Author: wgao9
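import numpy as np
import scipy.spatial as ss
from scipy.special import digamma
from math import log

# vd(d, q) is a helper defined elsewhere in lnn.py (not shown here); from how it
# is used below, it should return the log-volume of the unit l_q ball in d dimensions.

def _KSG_mi(data, split, k=5):
    '''
        Estimate the mutual information I(X;Y) from samples {x_i, y_i}_{i=1}^N
        using the KSG (Kraskov-Stögbauer-Grassberger) estimator with the Euclidean norm.

        Input: data: 2D array-like of shape N x (d_x + d_y)
               split: d_x, i.e. the first `split` columns form X, the rest form Y
               k: k-nearest-neighbor parameter

        Output: a single number, the estimate of I(X;Y) in nats
    '''
    # Convert to an array so that column slicing (data[:, :split]) works for lists too.
    data = np.asarray(data, dtype=float)
    assert split >= 1, "x must have at least one dimension"
    assert split <= data.shape[1] - 1, "y must have at least one dimension"
    N = data.shape[0]
    x = data[:, :split]
    y = data[:, split:]
    dx = x.shape[1]
    dy = y.shape[1]

    tree_xy = ss.cKDTree(data)
    tree_x = ss.cKDTree(x)
    tree_y = ss.cKDTree(y)

    # Distance from each sample to its k-th nearest neighbor in the joint space
    # (k + 1 because the query point itself is returned first, at distance 0).
    knn_dis = tree_xy.query(data, k + 1, p=2)[0][:, k]

    # Constant term: psi(k) + log N + log(c_dx * c_dy / c_{dx+dy}),
    # where c_d is the volume of the unit Euclidean ball in d dimensions.
    ans = digamma(k) + log(N) + vd(dx, 2) + vd(dy, 2) - vd(dx + dy, 2)
    for i in range(N):
        # Count marginal neighbors within knn_dis[i], excluding the point itself (-1).
        n_y = len(tree_y.query_ball_point(y[i], knn_dis[i], p=2)) - 1
        n_x = len(tree_x.query_ball_point(x[i], knn_dis[i], p=2)) - 1
        ans += -log(n_y) / N - log(n_x) / N

    return ans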
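The function depends on a helper `vd` that this excerpt does not show. The sketch below is a self-contained stand-in, assuming (from how `vd` is used) that `vd(d, q)` returns the log-volume of the unit l_q ball, log[(2 Γ(1 + 1/q))^d / Γ(1 + d/q)]. It also restates the same estimator in vectorized form under the hypothetical name `ksg_mi_l2` (its `return_length=True` argument to `query_ball_point` needs SciPy 1.6 or later) and compares it with the closed-form mutual information of a bivariate Gaussian with correlation r, I = -½ log(1 - r²).

```python
import math

import numpy as np
from scipy.spatial import cKDTree
from scipy.special import digamma


def vd(d, q):
    """Hypothetical stand-in for lnn's vd: log-volume of the unit l_q ball in d dims."""
    if q == float("inf"):
        return d * math.log(2)  # the unit l_inf ball is the cube [-1, 1]^d
    return d * math.log(2 * math.gamma(1 + 1.0 / q)) - math.lgamma(1 + d / q)


def ksg_mi_l2(data, split, k=5):
    """Hypothetical vectorized form of the KSG estimator above (Euclidean norm).

    I(X;Y) ~ psi(k) + log N + log(c_dx c_dy / c_{dx+dy}) - mean_i[log n_x,i + log n_y,i]
    """
    data = np.asarray(data, dtype=float)
    n = data.shape[0]
    x, y = data[:, :split], data[:, split:]
    dx, dy = x.shape[1], y.shape[1]
    # k-th nearest-neighbor distance in the joint space (index 0 is the point itself).
    rho = cKDTree(data).query(data, k + 1, p=2)[0][:, k]
    # Marginal neighbor counts within rho, excluding the point itself.
    n_x = cKDTree(x).query_ball_point(x, rho, p=2, return_length=True) - 1
    n_y = cKDTree(y).query_ball_point(y, rho, p=2, return_length=True) - 1
    const = digamma(k) + math.log(n) + vd(dx, 2) + vd(dy, 2) - vd(dx + dy, 2)
    return const - np.mean(np.log(n_x) + np.log(n_y))


# Sanity check on correlated Gaussian samples, where the true MI is known.
rng = np.random.default_rng(0)
r = 0.8
cov = [[1.0, r], [r, 1.0]]
samples = rng.multivariate_normal([0.0, 0.0], cov, size=2000)
print("estimate:", ksg_mi_l2(samples, split=1, k=5))
print("exact:   ", -0.5 * math.log(1 - r ** 2))
```

Querying all points at once removes the per-sample Python loop; on continuous data the two printed values should be close, with the gap shrinking as the sample size grows.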