def __plot_scatter(prl, max_points=None, fileout="PR_scatter.png", title="Precision Recall Scatter Plot"):
    """
    Save a precision/recall scatter plot coloured by local point density.

    The density of the plotted points is estimated with ``gaussian_kde`` and
    rescaled to the range [0, 1]. Points are drawn in order of increasing
    density so that the densest ones end up on top.

    Errors raised by ``gaussian_kde`` (for example on empty input) are not
    caught here and propagate to the caller.

    :param prl: list of tuples (precision, recall)
    :param max_points: max number of tuples to plot; the first ``max_points``
        entries of ``prl`` are used, ``None`` plots all of them
    :param fileout: output filename
    :param title: plot title
    """
    # Slicing with ``None`` keeps everything, so one expression covers both
    # the limited and the unlimited case.
    points = list(prl)[:max_points]
    x = np.array([p[0] for p in points], dtype=float)
    y = np.array([p[1] for p in points], dtype=float)

    # Evaluate the kernel density estimate at the sample points themselves.
    xy = np.vstack([x, y])
    density = gaussian_kde(xy)(xy)

    # Rescale to [0, 1]. When every point has the same density the spread is
    # zero; fall back to a constant colour instead of dividing by zero.
    lowest = density.min()
    spread = density.max() - lowest
    if spread > 0:
        density = (density - lowest) / spread
    else:
        density = np.zeros_like(density)

    # Plot low-density points first so high-density points are drawn on top.
    order = density.argsort()
    x, y, density = x[order], y[order], density[order]

    fig, ax = plt.subplots()
    try:
        # 'none' draws markers without an outline; an empty string is not a
        # valid colour specification.
        sca = ax.scatter(x, y, c=density, s=50, edgecolors='none', cmap=plt.cm.jet)
        fig.colorbar(sca)
        ax.set_ylabel("Recall", fontsize=20, labelpad=15)
        ax.set_xlabel("Precision", fontsize=20)
        # Small margin so points lying exactly on 0 or 1 are not clipped.
        ax.set_ylim([-0.01, 1.01])
        ax.set_xlim([-0.01, 1.01])
        ax.set_title(title)
        # NOTE(review): the backend-specific layout call is kept from the
        # original; presumably a workaround for these backends - confirm.
        if matplotlib.get_backend().lower() in ['agg', 'macosx']:
            fig.set_tight_layout(True)
        else:
            fig.tight_layout()
        fig.savefig("%s" % fileout)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
# Comment list
# Table of contents