plots.py source code

python

Project: cgpm Author: probcomp
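import matplotlib.pyplot as plt
import numpy as np


def plot_heatmap(
        D, xordering=None, yordering=None, xticklabels=None,
        yticklabels=None, vmin=None, vmax=None, ax=None):
    """Draw matrix D as a heatmap, optionally permuting rows and columns."""
    import seaborn as sns  # Imported lazily; only needed when plotting.
    D = np.copy(D)  # Copy (and convert to ndarray) so the input is untouched.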

    if ax is None:
        _, ax = plt.subplots()
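    # Columns of D run along the x-axis and rows along the y-axis.
    if xticklabels is None:
        xticklabels = np.arange(D.shape[1])
    if yticklabels is None:
        yticklabels = np.arange(D.shape[0])
    # Convert to arrays so that list labels can be indexed by an ordering.
    xticklabels = np.asarray(xticklabels)
    yticklabels = np.asarray(yticklabels)
    if xordering is not None:
        xticklabels = xticklabels[xordering]
        D = D[:,xordering]
    if yordering is not None:
        yticklabels = yticklabels[yordering]
        D = D[yordering,:]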

    sns.heatmap(
        D, yticklabels=yticklabels, xticklabels=xticklabels,
        linewidths=0.2, cmap='BuGn', ax=ax, vmin=vmin, vmax=vmax)
    ax.set_xticklabels(xticklabels, rotation=90)
    ax.set_yticklabels(yticklabels, rotation=0)
    return ax
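
The reordering step of plot_heatmap can be checked without drawing anything. The sketch below is an illustration, not part of cgpm: apply_orderings is a hypothetical helper that repeats only the label and matrix permutation logic, assuming x labels index the columns of D and y labels index its rows.

```python
import numpy as np


def apply_orderings(D, xordering=None, yordering=None,
                    xticklabels=None, yticklabels=None):
    """Hypothetical helper: the permutation step of plot_heatmap, no plotting."""
    D = np.copy(D)
    # Default labels: column indices on x, row indices on y (an assumption).
    if xticklabels is None:
        xticklabels = np.arange(D.shape[1])
    if yticklabels is None:
        yticklabels = np.arange(D.shape[0])
    xticklabels = np.asarray(xticklabels)
    yticklabels = np.asarray(yticklabels)
    if xordering is not None:
        # Permute the columns and their labels together.
        xticklabels = xticklabels[xordering]
        D = D[:, xordering]
    if yordering is not None:
        # Permute the rows and their labels together.
        yticklabels = yticklabels[yordering]
        D = D[yordering, :]
    return D, xticklabels, yticklabels


D = np.array([[1, 2, 3],
              [4, 5, 6]])
D2, xlabels, ylabels = apply_orderings(
    D, xordering=[2, 0, 1], yordering=[1, 0],
    xticklabels=['a', 'b', 'c'], yticklabels=['r0', 'r1'])
print(D2)       # rows swapped, columns rotated
print(xlabels)  # labels follow their columns
print(ylabels)  # labels follow their rows
```

Because labels and data are permuted with the same index array, each tick label stays attached to the row or column it originally described.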