est_rel_entro_HJW.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:HJW_KL_divergence_estimator 作者: Mathegineer 项目源码 文件源码
def log_mat(x, n, g_coeff, c_1, const):
    with np.errstate(divide='ignore', invalid='ignore'):
        K = g_coeff.shape[0] - 1
        thres = 2 * c_1 * math.log(n) / n
        [T, X] = np.meshgrid(thres, x)
        ratio = np.clip(2*X/T - 1, 0, 1)
        # force MATLAB-esque behavior with NaN, inf
        ratio[T == 0] = 1.0
        ratio[X == 0] = 0.0
        q = np.reshape(np.arange(K), [1, 1, K])
        g = np.tile(np.reshape(g_coeff, [1, 1, K + 1]), [c_1.shape[1], 1])
        g[:, :, 0] = g[:, :, 0] + np.log(thres)
        MLE = np.log(X) + (1-X) / (2*X*n)
        MLE[X == 0] = -np.log(n) - const
        tmp = (n*X[:,:,np.newaxis] - q)/(T[:,:,np.newaxis]*(n - q))
        polyApp = np.sum(np.cumprod(np.dstack([np.ones(T.shape + (1,)), tmp]),
                                    axis=2) * g, axis=2)
        polyFail = np.logical_or(np.isnan(polyApp), np.isinf(polyApp))
        polyApp[polyFail] = MLE[polyFail]
        return ratio*MLE + (1-ratio)*polyApp
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号