common.py source code

python
Project: tf-exercise-gan  Author: sanghoon
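
import io

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np


def plot(samples, figId=None, retBytes=False, shape=None):
    # Draw up to 16 samples on a 4x4 grid of borderless subplots.
    if figId is None:
        fig = plt.figure(figsize=(4, 4))
    else:
        # Reuse (or create) the numbered figure so repeated calls target one window.
        fig = plt.figure(figId, figsize=(4, 4))

    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        # Guard the length so a 2-D shape such as (28, 28) no longer raises IndexError.
        if shape is not None and len(shape) == 3 and shape[2] == 3:
            # RGB sample: clip to the valid [0, 1] float range before display.
            rescaled = np.clip(sample, 0.0, 1.0)
            plt.imshow(rescaled.reshape(*shape))
        else:
            # Otherwise assume a flattened 28x28 grayscale image, such as an MNIST digit.
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    if retBytes:
        # Render this figure to an in-memory PNG and rewind the buffer for the caller.
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        return fig, buf

    return fig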
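
The function hard-codes a 4x4 GridSpec, so it can show at most 16 samples, and `gs[i]` addresses the cells in row-major order (cell i sits at row i // 4, column i % 4); indexing past the last cell raises an error. The sketch below shows one way to size a near-square grid for any number of samples. The helper names `grid_shape` and `cell_position` are hypothetical and are not part of tf-exercise-gan; the resulting `(rows, cols)` could be passed to `gridspec.GridSpec(rows, cols)` in place of the fixed `(4, 4)`.

```python
import math


def grid_shape(n_samples):
    """Return (rows, cols) of the smallest near-square grid holding n_samples.

    Hypothetical helper: the original plot() always uses a 4x4 grid.
    """
    if n_samples < 1:
        raise ValueError("need at least one sample")
    cols = math.isqrt(n_samples - 1) + 1  # exact integer ceil(sqrt(n_samples))
    rows = -(-n_samples // cols)          # ceiling division
    return rows, cols


def cell_position(i, cols):
    """Map a flat subplot index to (row, col) in row-major order."""
    return divmod(i, cols)


print(grid_shape(16))       # (4, 4)
print(grid_shape(10))       # (3, 4)
print(cell_position(5, 4))  # (1, 1)
```

With 16 samples this reproduces the original 4x4 layout; with fewer samples it avoids a mostly empty grid.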
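
In the `retBytes` branch, `buf.seek(0)` matters because writing the PNG leaves the stream position at the end of the data, so a caller that immediately calls `buf.read()` would get nothing. The minimal sketch below isolates that behaviour with the standard library only; `buffer_contents` is a hypothetical name used for illustration, and the payload is just the 8-byte PNG signature rather than a full image.

```python
import io


def buffer_contents(payload: bytes, rewind: bool) -> bytes:
    """Write payload to an in-memory buffer, then read it back.

    Hypothetical helper mirroring the retBytes branch of plot().
    """
    buf = io.BytesIO()
    buf.write(payload)  # stream position is now len(payload)
    if rewind:
        buf.seek(0)     # move back to the first byte
    return buf.read()   # reads from the current position to the end


png_signature = b"\x89PNG\r\n\x1a\n"  # the 8 bytes every PNG file starts with
print(buffer_contents(png_signature, rewind=False))  # b''
print(buffer_contents(png_signature, rewind=True))   # b'\x89PNG\r\n\x1a\n'
```

An alternative is `buf.getvalue()`, which returns the whole buffer regardless of the current position.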