test_completely_positive.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:sdp_kmeans 作者: simonsfoundation 项目源码 文件源码
def _log_axis_lower_limit(values):
    """Return a lower y-limit for a log-scaled plot of *values* that is valid.

    First tries the minimum rounded down to three decimals. If that rounds
    to zero or below, which a log axis cannot display, it falls back to the
    largest power of ten not exceeding the smallest positive value. Returns
    None (let matplotlib choose) when no value is positive.
    """
    values = np.asarray(values, dtype=float)
    lower = np.floor(values.min() * 1e3) / 1e3
    if lower > 0:
        return lower
    positive = values[values > 0]
    if positive.size == 0:
        return None
    return 10.0 ** np.floor(np.log10(positive.min()))


def test_reconstruction(X, gt, n_clusters, filename, from_file=False,
                        n_trials=50):
    """Measure and plot how well ``symnmf_admm`` reconstructs the SDP output.

    Runs ``sdp_kmeans`` on ``X``. For every rank ``k`` in
    ``1 .. 200 + len(X)`` it factorizes the last matrix returned by
    ``sdp_kmeans`` with ``symnmf_admm`` ``n_trials`` times. Each run is
    scored with ``check_completely_positivity``; the result is plotted as
    the relative reconstruction error.

    The errors and ranks are cached in ``<dir_name><filename>.mat`` and can
    be reloaded with ``from_file=True``. Two PDFs are then written:

    * ``<dir_name><filename>_solution.pdf``: the clustered data plus one
      panel per matrix returned by ``sdp_kmeans``;
    * ``<dir_name><filename>_curve.pdf``: mean error versus rank with a
      +/- 2 std band, and a dashed vertical line at ``n_clusters``.

    Figures are left open so a caller can still ``plt.show()`` them.

    Parameters
    ----------
    X : array-like
        Data passed to ``sdp_kmeans`` and ``plot_data_clustered``.
        ``len(X)`` also sets the largest rank tried.
    gt : array-like
        Ground-truth labels passed to ``plot_data_clustered``
        (presumably used to color the points).
    n_clusters : int
        Number of clusters given to ``sdp_kmeans``; also marked on the curve.
    filename : str
        Base name for the cache and the PDFs. It is appended directly to the
        module-level ``dir_name`` with no separator, so ``dir_name``
        presumably ends with a path separator.
    from_file : bool, optional
        If True, load ``rec_errors`` and ``k_values`` from the ``.mat`` cache
        instead of recomputing them. The SDP is still solved, for plotting.
    n_trials : int, optional
        Number of ``symnmf_admm`` runs per rank (default 50). Ignored when
        ``from_file`` is True.

    Raises
    ------
    ValueError
        If ``n_trials < 1`` when errors are being recomputed.
    """
    Ds = sdp_kmeans(X, n_clusters, method='cvx')
    # The factorization target is the last matrix returned by sdp_kmeans.
    target = Ds[-1]
    mat_path = '{}{}.mat'.format(dir_name, filename)

    if from_file:
        data = scipy.io.loadmat(mat_path)
        rec_errors = data['rec_errors']
        k_values = data['k_values']
    else:
        if n_trials < 1:
            raise ValueError('n_trials must be >= 1, got {}'.format(n_trials))
        k_values = np.arange(200 + len(X)) + 1
        rec_errors = []
        for k in k_values:
            print('{} / {}'.format(k, k_values[-1]))
            # Each run is random, so repeat it to estimate the spread.
            rec_errors.append([
                check_completely_positivity(target, symnmf_admm(target, k=k))
                for _ in range(n_trials)
            ])
        rec_errors = np.array(rec_errors)
        scipy.io.savemat(mat_path,
                         dict(rec_errors=rec_errors,
                              k_values=k_values))

    sns.set_style('white')

    # First panel: the clustered data. Then one panel per matrix returned
    # by sdp_kmeans. Sizing the grid from len(Ds), rather than a fixed
    # 3 columns, keeps gs[i + 1] in range whatever len(Ds) is.
    plt.figure(tight_layout=True)
    gs = gridspec.GridSpec(1, 1 + len(Ds))

    ax = plt.subplot(gs[0])
    plot_data_clustered(X, gt, ax=ax)

    for i, D_input in enumerate(Ds):
        ax = plt.subplot(gs[i + 1])
        plot_matrix(D_input, ax=ax)
        if i == 0:
            ax.set_title('Original Gramian')
        else:
            ax.set_title('Layer {} (k={})'.format(i, n_clusters))
    plt.savefig('{}{}_solution.pdf'.format(dir_name, filename))

    plt.figure(tight_layout=True)
    # loadmat returns arrays that are at least 2-D; squeeze to a 1-D axis.
    k_axis = np.squeeze(k_values)
    mean = np.mean(rec_errors, axis=1)
    std = np.std(rec_errors, axis=1)
    sns.set_palette('muted')
    plt.fill_between(k_axis, mean - 2 * std, mean + 2 * std,
                     alpha=0.3)
    plt.semilogy(k_axis, mean, linewidth=2)
    plt.semilogy([n_clusters, n_clusters], [mean.min(), mean.max()],
                 linestyle='--', linewidth=2)
    plt.xlabel('$r$', size='xx-large')
    plt.ylabel('Relative reconstruction error', size='xx-large')
    # The y-axis is logarithmic, so the lower limit must be strictly positive.
    plt.ylim(_log_axis_lower_limit(rec_errors), 1)
    plt.savefig('{}{}_curve.pdf'.format(dir_name, filename))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号