def show_fit(filename, ax=None):
    """Util function. Show a gaussian fit on nearest distances.

    The distances stored in ``filename`` are mirrored around zero (a negated,
    sorted copy is prepended to the original values), a 3-component Gaussian
    mixture is fitted to the mirrored sample, and the histogram, the fitted
    density and the estimated threshold are plotted.

    The threshold is the smallest point of a 10000-point grid over [0, 1]
    that the fitted mixture assigns to the component with the largest mean.

    Usage example:
        np.mean([show_fit(f) for f in list_naive_npy])

    Parameters
    ----------
    filename : str, path-like or file object
        Anything accepted by ``np.load`` that holds the distances.
    ax : matplotlib Axes, optional
        Axes to draw on. Defaults to the current axes (``plt.gca()``).

    Returns
    -------
    threshold : numpy.ndarray of shape (1,) or None
        The estimated threshold, or ``None`` when there are fewer than two
        distances or no grid point in [0, 1] falls in the highest-mean
        component.
    """
    dist2nearest = np.asarray(np.load(filename)).reshape(-1, 1)
    if dist2nearest.shape[0] < 2:
        print("Cannot fit a Gaussian with fewer than two distances.")
        return None

    # Imported lazily so the early exit above does not need scikit-learn.
    # ``GaussianMixture`` replaces ``GMM``, removed in scikit-learn 0.20.
    from sklearn.mixture import GaussianMixture

    # Negated, sorted copy first, then the original values (same order as
    # the former list-concatenation), as a single (2n, 1) column.
    mirrored = np.concatenate((-np.sort(dist2nearest, axis=0), dist2nearest))
    gmm = GaussianMixture(n_components=3)
    gmm.fit(mirrored)

    if ax is None:
        ax = plt.gca()
    # ``density`` replaces the ``normed`` keyword removed in matplotlib 3.1.
    ax.hist(mirrored, bins=50, density=True)
    grid = np.linspace(-1, 1, 1000)[:, np.newaxis]
    # GaussianMixture.score_samples returns the per-sample log-density.
    ax.plot(grid, np.exp(gmm.score_samples(grid)), 'r')

    # Predict on ``lin`` itself: the index found below is used to read from
    # ``lin``, so predicting on any other grid would misplace the threshold.
    lin = np.linspace(0, 1, 10000)[:, np.newaxis]
    pred = gmm.predict(lin)
    # One feature, so means_ has shape (3, 1) and the flat argmax is the
    # index of the component with the largest mean.
    top_component = np.argmax(gmm.means_)
    hits = np.flatnonzero(pred == top_component)  # ascending indices
    if hits.size == 0:
        print("No point in [0, 1] is assigned to the highest-mean component.")
        plt.show()
        return None

    threshold = lin[hits[0]]
    ax.axvline(x=float(threshold[0]), linestyle='--', color='r')
    plt.show()
    return threshold  # threshold
评论列表
文章目录