image.py source code

python

Project: modl   Author: arthurmensch
import numpy as np
from math import ceil, sqrt


def plot_patches(fig, patches):
    """Draw up to 256 patches on a square grid of subplots of `fig`.

    `patches` has shape (n_patches, height, width) or
    (n_patches, height, width, n_channels).
    """
    if patches.ndim == 4:
        if patches.shape[3] == 1:
            # Single channel: drop the channel axis to get 2-D images.
            patches = patches[:, :, :, 0]
        else:
            # Keep at most the first three channels and lay them out side by
            # side: (n, h, w, c) -> (n, h, c, w) -> (n, h, c * w).
            patches = patches[:, :, :, :3]
            n, height, width, c = patches.shape
            patches = np.moveaxis(patches, 3, 2).reshape((n, height, c * width))
    # Plot at most 256 patches on a ceil(sqrt(n)) x ceil(sqrt(n)) grid.
    patches = patches[:256]
    side_size = ceil(sqrt(patches.shape[0]))
    for i, patch in enumerate(patches):
        ax = fig.add_subplot(side_size, side_size, i + 1)
        ax.imshow(patch, interpolation='nearest')
        ax.set_xticks(())
        ax.set_yticks(())

    fig.subplots_adjust(left=0.08, bottom=0.02, right=0.92, top=0.85,
                        wspace=0.08, hspace=0.23)
    return fig
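The multichannel branch is the least obvious step of `plot_patches`. The standalone sketch below reproduces that axis move and reshape on a tiny array, plus the grid-size rule, so the resulting layout can be checked directly. It is an illustration, not part of modl: the helpers `tile_channels` and `grid_side` are hypothetical names that only mirror the logic of the function above.

```python
import numpy as np
from math import ceil, sqrt


def tile_channels(patches):
    """Hypothetical helper mirroring plot_patches' multichannel branch.

    (n, h, w, c) -> (n, h, min(c, 3) * w): the first three channels of
    each patch are laid out left to right as separate grayscale tiles.
    """
    patches = patches[:, :, :, :3]
    n, height, width, c = patches.shape
    # (n, h, w, c) -> (n, h, c, w), then merge the last two axes row by row.
    return np.moveaxis(patches, 3, 2).reshape((n, height, c * width))


def grid_side(n_patches, max_patches=256):
    """Hypothetical helper: side of the square subplot grid for n_patches."""
    return ceil(sqrt(min(n_patches, max_patches)))


# Two 2x2 patches with 4 channels; channel k of every pixel holds value k.
patches = np.zeros((2, 2, 2, 4)) + np.arange(4)
tiled = tile_channels(patches)
print(tiled.shape)                      # (2, 2, 6)
print(tiled[0, 0])                      # [0. 0. 1. 1. 2. 2.]
print(grid_side(10), grid_side(1000))   # 4 16
```

Each row of a tiled patch holds channel 0, then channel 1, then channel 2, so the fourth channel is dropped and the result is a 2-D array rather than an RGB image; `imshow` therefore renders it with its default colormap. Because at most 256 patches are kept, the grid never exceeds 16 x 16 subplots.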