def plot(samples, max_width=12):
    """Draw samples on a square grid of image axes and return the figure.

    The grid is ``width x width`` where ``width`` is
    ``min(max_width, int(sqrt(len(samples))))``, floored at 1; only the
    first ``width * width`` samples are drawn and the rest are ignored.

    Args:
        samples: Sized iterable of images. Each image is rescaled with
            ``x * 0.5 + 0.5`` and transposed from ``(C, H, W)`` to
            ``(H, W, C)``, so the code assumes channels-first data in
            ``[-1, 1]`` (e.g. a tanh generator output) -- confirm against
            callers.
        max_width: Maximum number of rows/columns in the grid. Defaults to
            12, the value previously hard-coded here.

    Returns:
        The matplotlib figure containing the grid.
    """
    # max(1, ...) keeps an empty batch from reaching GridSpec(0, 0), which raises.
    width = max(1, min(max_width, int(np.sqrt(len(samples)))))
    fig = plt.figure(figsize=(width, width))
    gs = gridspec.GridSpec(width, width)
    gs.update(wspace=0.05, hspace=0.05)
    for ind, sample in enumerate(samples):
        if ind >= width * width:
            break
        # Add to ``fig`` explicitly instead of relying on pyplot's "current figure".
        ax = fig.add_subplot(gs[ind])
        ax.axis('off')  # hides the frame, ticks and tick labels
        ax.set_aspect('equal')
        # Rescale [-1, 1] -> [0, 1]; clip so imshow does not warn (and
        # silently clip) for every slightly out-of-range RGB pixel.
        image = np.clip(np.asarray(sample) * 0.5 + 0.5, 0.0, 1.0)
        image = np.transpose(image, (1, 2, 0))  # (C, H, W) -> (H, W, C)
        if image.shape[-1] == 1:
            # Single channel: older matplotlib rejects (H, W, 1) and newer
            # versions apply a false-colour map; render as grayscale instead.
            ax.imshow(image[:, :, 0], cmap='gray', vmin=0.0, vmax=1.0)
        else:
            ax.imshow(image)
    return fig
评论列表
文章目录