lnn.py source file

python

Project: lnn  Author: wgao9
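import numpy as np
import scipy.spatial as ss

# LNN_1_entropy is defined elsewhere in lnn.py.
def _3LNN_1_KSG_mi(data,split,k=5,tr=30):
    '''
        Estimate the mutual information I(X;Y) from samples {x_i,y_i}_{i=1}^N
        using I(X;Y) = H_{LNN}(X) + H_{LNN}(Y) - H_{LNN}(X,Y) with the "KSG trick",
        where H_{LNN} is the LNN entropy estimator of order 1.

        Input: data: 2D array-like of size N*(d_x + d_y)
        split: should be d_x, splitting the data into two parts, X and Y
        k: k-nearest neighbor parameter
        tr: number of samples used for computation

        Output: a single number, the estimate of I(X;Y)
    '''
    data = np.asarray(data)  # column slicing below requires an array, not a list of lists
    assert split >= 1, "x must have at least one dimension"
    assert split <= data.shape[1] - 1, "y must have at least one dimension"
    x = data[:,:split]
    y = data[:,split:]

    # "KSG trick": take each point's k-th nearest-neighbor distance in the joint
    # space and pass it as bw to all three entropy estimates.
    tree_xy = ss.cKDTree(data)
    # Query k+1 neighbors because the nearest neighbor of a point is the point itself.
    knn_dis = [tree_xy.query(point,k+1,p=2)[0][k] for point in data]
    return LNN_1_entropy(x,k,tr,bw=knn_dis) + LNN_1_entropy(y,k,tr,bw=knn_dis) - LNN_1_entropy(data,k,tr,bw=knn_dis)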
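
The two ingredients the function relies on can be sketched in isolation. The block below is a minimal illustration, not part of lnn.py: `kth_neighbor_distances` and `gaussian_mi_from_entropies` are hypothetical helper names introduced here. The first reproduces the joint-space k-th nearest-neighbor distances with a single batched `cKDTree.query` call. The second checks the identity I(X;Y) = H(X) + H(Y) - H(X,Y) on a unit-variance bivariate Gaussian, where every entropy has a closed form. It does not implement the LNN entropy estimator itself.

```python
import numpy as np
import scipy.spatial as ss


def kth_neighbor_distances(data, k=5):
    """Euclidean distance from each row of `data` to its k-th nearest neighbor.

    Hypothetical helper: mirrors the knn_dis computation in the joint space.
    """
    data = np.asarray(data, dtype=float)
    tree = ss.cKDTree(data)
    # Ask for k+1 neighbors: the closest match to each point is the point itself.
    dists, _ = tree.query(data, k + 1, p=2)
    return dists[:, k]


def gaussian_mi_from_entropies(rho):
    """I(X;Y) = H(X) + H(Y) - H(X,Y) for a unit-variance bivariate Gaussian.

    Hypothetical helper: uses closed-form differential entropies in nats.
    """
    h_x = 0.5 * np.log(2 * np.pi * np.e)
    h_y = h_x
    cov = np.array([[1.0, rho], [rho, 1.0]])
    h_xy = 0.5 * np.log((2 * np.pi * np.e) ** 2 * np.linalg.det(cov))
    return h_x + h_y - h_xy


rng = np.random.default_rng(0)
sample = rng.normal(size=(200, 3))   # N = 200 points, d_x + d_y = 3
bw = kth_neighbor_distances(sample, k=5)
mi = gaussian_mi_from_entropies(0.6)
```

For the Gaussian case the result agrees with the known formula -0.5 * log(1 - rho^2), which gives a ground truth to compare an estimator against on synthetic data.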