def plot_all_weights(model, n=64, n_columns=3, **kwargs):
    """Plot one weight mosaic per layer of ``model`` that has 4-D weights.

    Every layer exposing a ``W`` attribute whose evaluated value has four
    dimensions gets its own subplot, laid out on a grid ``n_columns`` wide.
    Each subplot shows the image returned by ``get_weights_mosaic`` with a
    colorbar attached to its right-hand side.

    Args:
        model: Object with a ``layers`` sequence; each layer is inspected for
            a ``W`` attribute and, when shown, for its ``name``.
        n: Forwarded unchanged to ``get_weights_mosaic`` (presumably the
            maximum number of filters per mosaic -- confirm against that
            helper).
        n_columns: Number of subplot columns in the figure grid. Must be >= 1.
        **kwargs: Extra keyword arguments passed to ``Axes.imshow``.
            ``interpolation`` defaults to ``"none"`` and ``cmap`` to
            ``"gray"`` when not supplied.

    Returns:
        The created ``matplotlib.figure.Figure``. When no layer qualifies the
        figure is returned empty.

    Raises:
        ValueError: If ``n_columns`` is smaller than 1.
    """
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    if n_columns < 1:
        raise ValueError(
            "n_columns must be a positive integer, got {!r}".format(n_columns))

    # Default imshow parameters; values supplied by the caller win.
    kwargs.setdefault('interpolation', "none")
    kwargs.setdefault('cmap', "gray")

    # Keep only the layers whose weights evaluate to a 4-D array.
    layers_to_show = []
    for layer_id, layer in enumerate(model.layers):
        if not hasattr(layer, "W"):
            continue
        weights = K.eval(layer.W.value())
        if weights.ndim == 4:
            layers_to_show.append((layer_id, layer))

    n_mosaic = len(layers_to_show)
    ncols = n_columns
    # Ceiling division so that every mosaic gets a grid cell; keep at least
    # one row so the figure never ends up with zero height.
    nrows = max(1, -(-n_mosaic // ncols))

    fig_w = 15
    fig_h = nrows * fig_w / ncols
    fig = plt.figure(figsize=(fig_w, fig_h))

    for i, (layer_id, layer) in enumerate(layers_to_show):
        mosaic = get_weights_mosaic(model, layer_id=layer_id, n=n)

        ax = fig.add_subplot(nrows, ncols, i + 1)
        im = ax.imshow(mosaic, **kwargs)
        ax.set_title("Layer #{} called '{}' \nof type {}".format(layer_id, layer.name, layer.__class__.__name__))

        # Carve a thin axes out of the right side of the subplot for the
        # colorbar.
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.1)
        plt.colorbar(im, cax=cax)

        ax.axis('off')

        # Only spines lying on the outer border of the grid are flagged
        # visible. The grid position is derived from the subplot index rather
        # than from the Axes.is_first_row()/is_last_col() helpers, which newer
        # Matplotlib releases no longer provide.
        # NOTE(review): axis('off') above likely suppresses spine drawing
        # altogether -- confirm whether this border logic is still wanted.
        row, col = divmod(i, ncols)
        for sp in ax.spines.values():
            sp.set_visible(False)
        if row == 0:
            ax.spines['top'].set_visible(True)
        if row == nrows - 1:
            ax.spines['bottom'].set_visible(True)
        if col == 0:
            ax.spines['left'].set_visible(True)
        if col == ncols - 1:
            ax.spines['right'].set_visible(True)

    return fig
评论列表
文章目录