plotting.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:bptd 作者: aschein 项目源码 文件源码
def plot_component(Y_NN, Lambda_CC, s_Theta_CN, r_Theta_CN, assignments_N,
                   scale_func=lambda x: np.log(x + 1), filename=None, figsize=None, dpi=None):
    """Draw a 2x2 heatmap panel for one component and save or show it.

    Layout (rows have height ratio 4:1, columns have width ratio 1:4):

    * top-left     -- ``s_Theta_CN`` transposed (blue)
    * top-right    -- ``Y_NN`` (red)
    * bottom-left  -- ``Lambda_CC`` (red)
    * bottom-right -- ``r_Theta_CN`` (blue)

    Every matrix is passed through ``scale_func`` before being plotted.
    Green lines are drawn on the ``Y``, ``s_Theta`` and ``r_Theta`` panels
    at each index where two consecutive entries of ``assignments_N`` differ.

    Args:
        Y_NN: Matrix shown in the top-right panel.
        Lambda_CC: Matrix shown in the bottom-left panel; its rows and
            columns are labelled 1..n.
        s_Theta_CN: Matrix whose transpose is shown in the top-left panel.
        r_Theta_CN: Matrix shown in the bottom-right panel.
        assignments_N: 1-D array (must expose ``.size``) of group labels,
            presumably ordered like the rows/columns of ``Y_NN`` -- confirm
            against the caller.
        scale_func: Callable applied to each matrix before plotting.
            Defaults to ``log(x + 1)``.
        filename: If not None, the figure is written there in PDF format;
            otherwise the figure is shown interactively.
        figsize: Forwarded to ``plt.figure``.
        dpi: Forwarded to ``plt.figure``.

    NOTE(review): ``actors`` and ``order_N`` are not parameters and are not
    defined in this function; they must be provided at module level by the
    surrounding file -- confirm, or promote them to arguments.
    """
    plt.figure(figsize=figsize, dpi=dpi)

    fontsize = 8
    height_ratios = [4, 1]
    width_ratios = [1, 4]

    gs = gridspec.GridSpec(2, 2, height_ratios=height_ratios, width_ratios=width_ratios)
    gs.update(wspace=0.025, hspace=0.025)
    ax1 = plt.subplot(gs[1, 0])  # Lambda
    ax2 = plt.subplot(gs[0, 0])  # s_Theta
    ax3 = plt.subplot(gs[1, 1])  # r_Theta
    ax4 = plt.subplot(gs[0, 1])  # Y

    # Derive the tick-label counts from the matrix itself; the original
    # referenced a name ``C`` that was never defined in this scope.
    n_lambda_rows, n_lambda_cols = Lambda_CC.shape
    sns.heatmap(scale_func(Lambda_CC), vmin=0, cmap='Reds', ax=ax1, cbar=False,
                xticklabels=range(1, n_lambda_cols + 1),
                yticklabels=range(1, n_lambda_rows + 1))
    plt.setp(ax1.get_yticklabels(), fontsize=fontsize, weight='bold')
    plt.setp(ax1.get_xticklabels(), fontsize=fontsize, weight='bold')

    sns.heatmap(scale_func(s_Theta_CN.T), ax=ax2, vmin=0, cmap='Blues',
                yticklabels=actors[order_N], cbar=False)
    plt.setp(ax2.get_yticklabels(), fontsize=fontsize, rotation=0, weight='bold')
    ax2.set_xticklabels([])

    sns.heatmap(scale_func(r_Theta_CN), ax=ax3, vmin=0, cmap='Blues',
                xticklabels=actors[order_N], cbar=False)
    plt.setp(ax3.get_xticklabels(), fontsize=fontsize, rotation=90, weight='bold')
    ax3.set_yticklabels([])

    sns.heatmap(scale_func(Y_NN), ax=ax4, vmin=0, cmap='Reds', cbar=False)

    # Mark every position where the group label changes between neighbours.
    # NOTE(review): the horizontal lines use ``N - i``, which matches a
    # heatmap drawn with row 0 at the bottom. Newer seaborn releases appear
    # to invert the y-axis instead, in which case ``i`` would be the matching
    # coordinate -- confirm against the seaborn version in use.
    N = assignments_N.size
    neighbours = zip(assignments_N[:-1], assignments_N[1:])
    for i, (previous, current) in enumerate(neighbours, 1):
        if current != previous:
            ax4.axvline(i, c='g', lw=2.)
            ax4.axhline(N - i, c='g', lw=2.)
            ax2.axhline(N - i, c='g', lw=2.)
            ax3.axvline(i, c='g', lw=2.)
    ax4.set_xticklabels([])
    ax4.set_yticklabels([])

    if filename is not None:
        plt.savefig(filename, format='pdf', bbox_inches='tight')
    else:
        plt.show()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号