def _weight_to_image(wt):
    """Return the array *wt* in a shape that can be plotted.

    - 0-D arrays become a length-1 1-D array (drawn as a line).
    - 1-D and 2-D arrays are returned unchanged.
    - 4-D arrays of shape ``(s0, s1, s2, s3)`` are tiled into one 2-D mosaic
      of shape ``(s0 * s2, s1 * s3)``. Each ``wt[i, j]`` slice ends up as an
      ``s2 x s3`` tile at tile-row ``i``, tile-column ``j``.
      NOTE(review): presumably the layout is (out, in, kh, kw) - confirm.
    - Any other array with more than two dimensions is flattened to
      ``(s0, -1)`` so it can still be drawn as an image.

    Only the array's own ``ndim``/``shape``/``transpose``/``reshape`` are used.
    The input is never modified; a (possibly new) array is returned.
    """
    if wt.ndim == 0:
        return wt.reshape(1)
    if wt.ndim == 4:
        s0, s1, s2, s3 = wt.shape
        # transpose/reshape return new arrays rather than working in place,
        # so the result must be kept.
        return wt.transpose(0, 2, 1, 3).reshape(s0 * s2, s1 * s3)
    if wt.ndim > 2:
        return wt.reshape(wt.shape[0], -1)
    return wt


def weight(repo, compare_repo=None):
    """Save './weightplot.html' showing the trainable parameters of one layer.

    The parameters are obtained via
    ``_combine_repo(repo, compare_repo, 'trainable_params')``. Of its result,
    only the first element (a mapping of layer name -> mapping of weight name
    -> array) is plotted; the comparison data is currently unused.

    The page contains a "Layer:" selector followed by a grid of 100x100
    figures, five per row: one per weight of the selected layer. 1-D weights
    are drawn as a line and everything else as an image (see
    ``_weight_to_image``). The grid for the first layer is rendered into the
    saved file.

    Switching layers relies on a Python ``on_change`` callback. Bokeh only
    runs such callbacks when the document is served by a Bokeh server; in the
    standalone HTML written here, the selector does not change the grid.
    """
    repo_saving, compare_saving, compare_name = _combine_repo(
        repo, compare_repo, 'trainable_params')
    # compare_saving / compare_name are not plotted yet.
    layer_names = list(repo_saving.keys())
    per_row = 5

    def build_grid(layername):
        """Build a gridplot with one small figure per weight of *layername*."""
        figures = []
        for wtname, wt in repo_saving[layername].items():
            img = _weight_to_image(wt)
            p = plt.figure(plot_width=100, plot_height=100, title=wtname,
                           tools=[])
            if img.ndim == 1:
                p.line(list(range(img.shape[0])), img, line_width=2,
                       color='black')
            else:
                # Draw the image so it spans exactly its own pixel extent.
                p.image(image=[img], x=0, y=0,
                        dw=img.shape[1], dh=img.shape[0])
            figures.append(p)
        # Chunk into rows without leaving a trailing empty row.
        rows = [figures[i:i + per_row]
                for i in range(0, len(figures), per_row)] or [[]]
        return lyt.gridplot(rows, toolbar_location="below", merge_tools=True)

    plt.output_file('./weightplot.html', title='WeightPlot')
    initial = layer_names[0] if layer_names else ''
    select = models.widgets.Select(title="Layer:", value=initial,
                                   options=layer_names)
    if layer_names:
        layout = lyt.column(select, build_grid(initial))
    else:
        layout = lyt.column(select)

    def show(attr, old, new):
        # Replace the grid below the selector with the newly selected layer's.
        layout.children = [select, build_grid(new)]

    select.on_change("value", show)
    plt.save(layout)