def _axis_aligned_ellipses(means, logstds, color, label):
    """Build one axis-aligned ``Ellipse`` patch per (mean, log-std) pair.

    Args:
        means: iterable of ellipse centres, passed as ``xy`` (presumably
            2-vectors — confirm against callers).
        logstds: iterable of per-component log standard deviations; only
            entries ``[0]`` (x axis) and ``[1]`` (y axis) are read.
        color: edge colour of every ellipse.
        label: legend label, attached to the first ellipse only so the
            legend shows a single entry for the whole group.

    Returns:
        list of ``Ellipse`` patches, one per zipped (mean, logstd) pair.
    """
    # matplotlib's width/height are full axis lengths, so 6 * std spans
    # +/- 3 standard deviations along each axis. angle=0 draws axis-aligned
    # ellipses, i.e. only per-axis spread is shown (no covariance rotation).
    return [
        Ellipse(
            xy=mean_,
            width=6 * np.exp(logstd_[0]),
            height=6 * np.exp(logstd_[1]),
            angle=0,
            facecolor='none',
            zorder=10,
            edgecolor=color,
            label=label if i == 0 else None,
        )
        for i, (mean_, logstd_) in enumerate(zip(means, logstds))
    ]


def get_plot_buf(x, clusters, mu, logstd, true_mu, true_logstd):
    """Render clustered points plus predicted/true ellipses into a PNG buffer.

    Args:
        x: point array indexed as ``x[:, 0]`` / ``x[:, 1]`` for the scatter;
            ``x.shape[0]`` is reported as N in the title.
        clusters: per-point values passed to ``scatter``'s ``c`` argument
            (presumably integer cluster assignments — confirm).
        mu: predicted component centres; ``mu.shape[0]`` is reported as K.
        logstd: predicted per-axis log standard deviations, one per centre.
        true_mu: ground-truth component centres.
        true_logstd: ground-truth per-axis log standard deviations.

    Returns:
        ``io.BytesIO`` holding the PNG image, rewound to offset 0.
    """
    N = x.shape[0]
    K = mu.shape[0]
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, aspect='auto')
        ax.scatter(x[:, 0], x[:, 1], c=clusters, s=50)

        # Predicted components are outlined in green, ground truth in red.
        ells = _axis_aligned_ellipses(mu, logstd, 'g', 'predict')
        true_ells = _axis_aligned_ellipses(true_mu, true_logstd, 'r', 'true')
        for ell in ells + true_ells:
            ax.add_patch(ell)

        # Only the ellipses carry labels; skip the legend when there are
        # none, otherwise matplotlib warns about missing labelled artists.
        if ells or true_ells:
            ax.legend(loc='best')
        ax.set_title('N={},K={}'.format(N, K))
        ax.autoscale(True)

        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Close this specific figure even if drawing/saving fails, so
        # repeated calls do not accumulate open pyplot figures.
        plt.close(fig)
    buf.seek(0)
    return buf
评论列表
文章目录