vi_GMM_2d.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:vi_vae_gmm 作者: wangg12 项目源码 文件源码
def _axis_aligned_ellipses(means, logstds, color, label, n_std=3):
    """Build one unfilled, axis-aligned Ellipse per (mean, logstd) pair.

    Each ellipse is centred at ``mean_`` and has full width/height
    ``2 * n_std * exp(logstd_[0])`` / ``2 * n_std * exp(logstd_[1])``.
    If ``logstd`` holds log standard deviations (as the name suggests --
    TODO confirm), the default ``n_std=3`` draws the +/-3 sigma contour.
    Only the first ellipse carries ``label`` so the legend shows a single
    entry per group.
    """
    return [
        Ellipse(
            xy=mean_,
            width=2 * n_std * np.exp(logstd_[0]),
            height=2 * n_std * np.exp(logstd_[1]),
            angle=0,
            facecolor='none',
            zorder=10,  # draw ellipses above the scatter points
            edgecolor=color,
            label=label if i == 0 else None,
        )
        for i, (mean_, logstd_) in enumerate(zip(means, logstds))
    ]


def get_plot_buf(x, clusters, mu, logstd, true_mu, true_logstd):
    """Plot clustered 2-D points with predicted and true component ellipses.

    The figure is rendered to an in-memory PNG and returned.

    Args:
        x: Points to scatter. Columns 0 and 1 are plotted (presumably an
            (N, 2) array -- TODO confirm with callers).
        clusters: Per-point colour values passed to ``scatter(c=...)``
            (presumably cluster assignments).
        mu: Predicted component centres, one row per component.
            ``mu.shape[0]`` is reported as K in the title.
        logstd: Per-component values whose exp gives the x/y half-extent
            scale of each predicted ellipse, aligned row-for-row with ``mu``.
        true_mu, true_logstd: Ground-truth counterparts of ``mu`` and
            ``logstd``, drawn in red.

    Returns:
        An ``io.BytesIO`` rewound to offset 0 that contains the PNG image.
    """
    N = x.shape[0]
    K = mu.shape[0]
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, aspect='auto')
        ax.scatter(x[:, 0], x[:, 1], c=clusters, s=50)
        # Predicted components in green, ground truth in red.
        for ell in _axis_aligned_ellipses(mu, logstd, 'g', 'predict'):
            ax.add_patch(ell)
        for ell in _axis_aligned_ellipses(true_mu, true_logstd, 'r', 'true'):
            ax.add_patch(ell)
        # With no components there are no labelled artists, and
        # ax.legend() would only emit a warning.
        if ax.get_legend_handles_labels()[0]:
            ax.legend(loc='best')
        ax.set_title('N={},K={}'.format(N, K))
        ax.autoscale(True)
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Release this specific figure even if drawing or saving fails,
        # so pyplot does not keep it alive.
        plt.close(fig)
    buf.seek(0)
    return buf
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号