def plot_weight_matrix(Z, outname, save=True, ncols=10):
    """Draw every ``Z[i]`` as a grayscale tile on a grid; optionally save it.

    Each tile is ``Z[i, :, :, :].mean(axis=0)``: the entry is averaged over
    its first remaining axis and rendered with ``imshow`` (``cmap='gray'``).
    Tiles are laid out ``ncols`` per row on an ``AxesGrid``.

    Args:
        Z: array indexed as ``Z[i, :, :, :]``, so it needs at least four
            dimensions. Presumably ``(num_filters, channels, height, width)``
            -- TODO confirm against callers.
        outname: destination passed to ``Figure.savefig`` when ``save`` is true.
        save: if true, save the figure to ``outname`` and then clear it.
            If false, the figure is left intact so the caller can show or
            modify it.
        ncols: number of tiles per grid row (default 10).

    Returns:
        The pyplot figure number 1 used for drawing. When ``save`` is true it
        has already been cleared by the time it is returned.

    Raises:
        ValueError: if ``ncols`` is less than 1.
    """
    if ncols < 1:
        raise ValueError("ncols must be >= 1, got %r" % (ncols,))
    num = Z.shape[0]
    # Integer ceiling division: enough rows to hold `num` tiles.
    nrows = (num + ncols - 1) // ncols

    # Figure number 1 is reused on every call; pyplot reactivates an existing
    # figure 1 instead of creating a new one.
    fig = plt.figure(1, (80, 80))
    fig.subplots_adjust(left=0.05, right=0.95)
    grid = AxesGrid(fig, (1, 4, 2),  # similar to subplot(142)
                    nrows_ncols=(nrows, ncols),
                    axes_pad=0.04,
                    share_all=True,
                    label_mode="L",
                    )
    for i in range(num):
        grid[i].imshow(Z[i, :, :, :].mean(axis=0), cmap='gray')
    # Hide axes on every cell, including unused trailing cells of the last row.
    for i in range(grid.ngrids):
        grid[i].axis('off')
    # No cbar_mode is configured, so these colorbar axes are presumably unused;
    # their labels are switched off anyway.
    for cax in grid.cbar_axes:
        cax.toggle_label(False)
    if save:
        fig.savefig(outname, bbox_inches='tight')
        # Clear only after saving so figure 1 starts empty on the next call;
        # clearing unconditionally would discard the plot when save=False.
        fig.clear()
    return fig
评论列表
文章目录