import numpy as np
import scipy.spatial as ss
from scipy.special import digamma
from math import log

def _KSG_mi(data, split, k=5):
    '''
    Estimate the mutual information I(X;Y) from samples {(x_i, y_i)}_{i=1}^N
    using the KSG mutual information estimator.
    Input:  data:  2D array (or list) of shape N*(d_x + d_y)
            split: d_x, the column index splitting each row into an X part and a Y part
            k:     k-nearest-neighbor parameter
    Output: scalar estimate of I(X;Y)
    '''
    assert split >= 1, "x must have at least one dimension"
    assert split <= len(data[0]) - 1, "y must have at least one dimension"
    data = np.asarray(data)
    N = len(data)
    x = data[:, :split]
    y = data[:, split:]
    dx = len(x[0])
    dy = len(y[0])
    # k-d trees on the joint space and on each marginal space
    tree_xy = ss.cKDTree(data)
    tree_x = ss.cKDTree(x)
    tree_y = ss.cKDTree(y)
    # Distance to the k-th nearest neighbor in the joint space
    # (k+1 because the query point itself is returned at index 0)
    knn_dis = [tree_xy.query(point, k + 1, p=2)[0][k] for point in data]
    # Correction terms: vd(d, 2) is assumed to be the log volume of the
    # d-dimensional unit l_2 ball (helper not shown in this snippet; see below)
    ans = digamma(k) + log(N) + vd(dx, 2) + vd(dy, 2) - vd(dx + dy, 2)
    for i in range(N):
        # Marginal neighbor counts within the joint kNN radius;
        # subtract 1 to exclude the query point itself
        nx = len(tree_x.query_ball_point(x[i], knn_dis[i], p=2)) - 1
        ny = len(tree_y.query_ball_point(y[i], knn_dis[i], p=2)) - 1
        ans += -(log(nx) + log(ny)) / N
    return ans
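
The function calls a helper vd(d, q) that is not defined in this snippet. In the KSG formula with l_2 distances, the corresponding term is the log volume of the unit l_q ball in d dimensions, so the helper presumably computes that quantity. A minimal sketch under that assumption (the name and signature match the call above; the body is an assumed reconstruction, not taken from the original source):

from math import log
from scipy.special import gammaln

def vd(d, q=2):
    # Assumed helper: log volume of the unit l_q ball in d dimensions,
    # log c_d = d*log(2*Gamma(1 + 1/q)) - log(Gamma(1 + d/q)).
    # For q = 2 this reduces to (d/2)*log(pi) - log(Gamma(d/2 + 1)).
    if q == float('inf'):
        return d * log(2)
    return d * (log(2) + gammaln(1 + 1.0 / q)) - gammaln(1 + d * 1.0 / q)

As a quick sanity check on synthetic data (hypothetical example, not from the original post): for a bivariate Gaussian with correlation rho, the true mutual information is -0.5*log(1 - rho^2), so the estimate should land near that value.

import numpy as np

rng = np.random.default_rng(0)
rho = 0.6
cov = [[1.0, rho], [rho, 1.0]]
samples = rng.multivariate_normal([0.0, 0.0], cov, size=2000)
print(_KSG_mi(samples, split=1, k=5))   # kNN estimate of I(X;Y)
print(-0.5 * np.log(1 - rho ** 2))      # ground truth, approximately 0.223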