def plot(samples, figId=None, retBytes=False, shape=None, *, nrows=4, ncols=4):
    """Draw image samples on an ``nrows`` x ``ncols`` grid of axes in one figure.

    Args:
        samples: Iterable of image arrays (flat or already shaped), one per
            grid cell, placed in row-major order.
        figId: Optional matplotlib figure number/label passed to
            ``plt.figure``. ``None`` creates a new figure.
        retBytes: If true, also render the figure to an in-memory PNG.
        shape: Optional image shape. When it has at least three entries and
            ``shape[2] == 3``, each sample is clipped to [0, 1] and drawn as
            an RGB image of this shape. Otherwise samples are drawn in
            grayscale, reshaped to ``shape[:2]`` when ``shape`` has at least
            two entries and to 28x28 when it does not.
        nrows: Number of grid rows (default 4, the original layout).
        ncols: Number of grid columns (default 4, the original layout).

    Returns:
        ``fig`` or, when ``retBytes`` is true, ``(fig, buf)``. ``buf`` is an
        ``io.BytesIO`` holding the PNG, rewound to position 0.

    Raises:
        IndexError: If ``samples`` yields more items than there are grid cells.
    """
    # plt.figure(None, ...) is the same as plt.figure(...): a fresh figure.
    fig = plt.figure(figId, figsize=(4, 4))
    gs = gridspec.GridSpec(nrows, ncols)
    gs.update(wspace=0.05, hspace=0.05)

    # Decide the drawing mode and target shape once, outside the loop.
    is_rgb = bool(shape) and len(shape) >= 3 and shape[2] == 3
    if is_rgb:
        image_shape = tuple(shape)
    elif shape and len(shape) >= 2:
        image_shape = tuple(shape[:2])
    else:
        image_shape = (28, 28)

    capacity = nrows * ncols
    for i, sample in enumerate(samples):
        if i >= capacity:
            raise IndexError(
                f"plot() got more than {capacity} samples for a "
                f"{nrows}x{ncols} grid"
            )
        # plt.subplot targets the current figure, which plt.figure just made `fig`.
        ax = plt.subplot(gs[i])
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        if is_rgb:
            # Float RGB data outside [0, 1] would be clipped by imshow with a
            # warning; clip explicitly first.
            ax.imshow(np.clip(sample, 0.0, 1.0).reshape(image_shape))
        else:
            ax.imshow(np.reshape(sample, image_shape), cmap='Greys_r')

    if retBytes:
        buf = io.BytesIO()
        # Save this specific figure rather than whatever pyplot considers current.
        fig.savefig(buf, format='png')
        buf.seek(0)
        return fig, buf
    return fig
# Scraped web-page residue, not part of the program:
# "评论列表" (comment list), "文章目录" (table of contents).