evaluate_threshold.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:icing 作者: slipguru 项目源码 文件源码
def show_fit(filename, ax=None):
    """Util function. Show a gaussian fit on nearest distances.

    Load an array of distances from ``filename``, mirror it around zero so
    the sample is symmetric, fit a 3-component Gaussian mixture, and plot
    the histogram of the mirrored sample together with the fitted density.
    The returned threshold is the first point of a fine grid on [0, 1]
    that the mixture assigns to the component with the largest mean.

    Usage example:
        np.mean([show_fit(f) for f in list_naive_npy])

    Parameters
    ----------
    filename : str or path-like
        File readable by ``np.load`` holding the nearest distances.
    ax : matplotlib Axes, optional
        Axes to draw on. Defaults to the current axes (``plt.gca()``).

    Returns
    -------
    threshold : ndarray of shape (1,) or None
        The estimated threshold. ``None`` is returned (after printing a
        message) when fewer than two distances are available, or when no
        grid point in [0, 1] is assigned to the highest-mean component.
    """
    from sklearn.mixture import GaussianMixture

    X = np.load(filename)
    dist2nearest = np.asarray(X).reshape(-1, 1)
    if dist2nearest.shape[0] < 2:
        print("Cannot fit a Gaussian with fewer than two distances.")
        return None

    # Mirror the distances around zero so the mixture sees a symmetric
    # sample. Sample order is irrelevant to the fit.
    # NOTE(review): assumes distances are non-negative -- confirm upstream.
    mirrored = np.concatenate((-dist2nearest, dist2nearest))
    gmm = GaussianMixture(n_components=3)
    gmm.fit(mirrored)

    if ax is None:
        ax = plt.gca()
    ax.hist(mirrored.ravel(), bins=50, density=True)
    plot_grid = np.linspace(-1, 1, 1000)[:, np.newaxis]
    # GaussianMixture.score_samples returns the per-sample log density.
    ax.plot(plot_grid, np.exp(gmm.score_samples(plot_grid)), 'r')

    # Predict and index on the SAME fine grid over [0, 1]; mixing two grids
    # of different range/resolution would yield a meaningless threshold.
    threshold_grid = np.linspace(0, 1, 10000)[:, np.newaxis]
    pred = gmm.predict(threshold_grid)
    top_component = np.argmax(gmm.means_)  # means_ is (n_components, 1)
    hits = np.flatnonzero(pred == top_component)
    if hits.size == 0:
        print("No point in [0, 1] is assigned to the highest-mean component.")
        plt.show()
        return None

    threshold = threshold_grid[hits[0]]  # shape (1,), as before
    ax.axvline(x=float(threshold[0]), linestyle='--', color='r')
    plt.show()
    return threshold
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号