plotter.py source code

python

Project: NNBuilder  Author: aeloyq
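# Assumed aliases: the module's own imports are not shown in this excerpt.
import numpy as np
from bokeh import layouts as lyt
from bokeh import models
from bokeh import plotting as plt
from bokeh.io import curdoc


def weight(repo, compare_repo=None):
    # _combine_repo is a helper defined elsewhere in plotter.py.
    repo_saving, compare_saving, compare_name = _combine_repo(repo, compare_repo, 'trainable_params')
    layer_names = list(repo_saving.keys())  # dict views are not indexable in Python 3

    def build_grid(layername):
        # One 100x100 figure per weight tensor of the layer, five figures per row.
        figs = []
        for wtname, wt in repo_saving[layername].items():
            wt = np.asarray(wt)
            if wt.ndim == 4:
                # transpose/reshape return new arrays, so reassign the results.
                # Tile (a, b, c, d) into an (a*c, b*d) mosaic of c-by-d kernels.
                wt = wt.transpose(0, 2, 1, 3)
                wt = wt.reshape(wt.shape[0] * wt.shape[1], wt.shape[2] * wt.shape[3])
            p = plt.figure(width=100, height=100, title=wtname, tools=[])
            if wt.ndim == 1:
                p.line(np.arange(wt.shape[0]), wt, line_width=2, color='black')
            else:
                # Draw the matrix over its own pixel extent (p.x_range is not indexable).
                p.image(image=[wt], x=0, y=0, dw=wt.shape[1], dh=wt.shape[0], palette='Greys256')
            figs.append(p)
        return lyt.gridplot(figs, ncols=5, toolbar_location="below", merge_tools=True)

    select = models.widgets.Select(title="Layer:", value=layer_names[0], options=layer_names)
    layout = lyt.column(select, build_grid(layer_names[0]))

    def show(attr, old, new):
        # Replace the grid with the one for the newly selected layer.
        layout.children[1] = build_grid(new)

    select.on_change("value", show)
    # Python on_change callbacks only run under a Bokeh server (`bokeh serve`);
    # a static HTML file written by plt.save() cannot call back into Python.
    doc = curdoc()
    doc.title = 'WeightPlot'
    doc.add_root(layout)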
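The `wt.ndim == 4` branch in `weight` flattens a 4-D weight into a single 2-D image before plotting. The sketch below isolates that step with NumPy so the shapes can be checked by hand. It assumes a conv-style layout such as (out_channels, in_channels, kh, kw), which the snippet itself does not state, and `tile_kernels` is a hypothetical helper name, not part of NNBuilder.

```python
import numpy as np


def tile_kernels(wt):
    """Flatten a 4-D weight of shape (a, b, c, d) into an (a*c, b*d) mosaic.

    Block (i, j) of the result, of size c x d, is the kernel wt[i, j], so
    block rows follow the first axis and block columns follow the second.
    """
    wt = np.asarray(wt)
    if wt.ndim != 4:
        raise ValueError("expected a 4-D array, got ndim=%d" % wt.ndim)
    a, b, c, d = wt.shape
    # (a, b, c, d) -> (a, c, b, d): each kernel row sits next to its block row.
    tiled = wt.transpose(0, 2, 1, 3)
    # Merge (a, c) into image rows and (b, d) into image columns.
    return tiled.reshape(a * c, b * d)


w = np.arange(2 * 3 * 2 * 2).reshape(2, 3, 2, 2)
m = tile_kernels(w)
print(m.shape)  # (4, 6)
```

Without the transpose, a plain reshape would interleave rows of different kernels; moving the kernel-height axis next to the first axis is what keeps each c-by-d kernel contiguous in the image.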