visu.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:keras-toolbox 作者: hadim 项目源码 文件源码
def plot_all_weights(model, n=64, n_columns=3, **kwargs):
    """Plot a weight mosaic for every layer of ``model`` that has 4-D weights.

    Each layer exposing a ``W`` attribute whose value (``layer.W.get_value()``)
    has ``ndim == 4`` gets one subplot. The subplot shows the image returned by
    ``get_weights_mosaic(model, layer_id=..., n=n)``, is titled with the layer
    index, name and class, and has a colorbar on its right. Subplots are laid
    out on a grid with ``n_columns`` columns and as many rows as needed.

    Parameters
    ----------
    model :
        Model whose ``layers`` are inspected. Layers are expected to expose
        their weights as ``layer.W`` with a ``get_value()`` method (presumably
        Theano shared variables -- confirm against the Keras version in use).
    n : int, optional
        Forwarded unchanged to ``get_weights_mosaic`` (default 64).
    n_columns : int, optional
        Number of subplot columns in the figure (default 3).
    **kwargs
        Extra keyword arguments passed to ``Axes.imshow``. ``interpolation``
        defaults to ``"none"`` and ``cmap`` to ``"gray"`` unless the caller
        supplies them.

    Returns
    -------
    matplotlib.figure.Figure
        The figure with one subplot per plotted layer. If no layer
        qualifies, a figure without subplots is returned.
    """
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    # Caller-supplied values win; only fill in the defaults.
    kwargs.setdefault('interpolation', "none")
    kwargs.setdefault('cmap', "gray")

    layers_to_show = [
        (i, layer)
        for i, layer in enumerate(model.layers)
        if hasattr(layer, "W") and layer.W.get_value().ndim == 4
    ]

    n_mosaic = len(layers_to_show)
    ncols = n_columns
    # Ceiling division: enough rows to hold every layer. The previous test
    # (`ncols ** 2 < n_mosaic`) could under-allocate rows (e.g. 4 layers on
    # 3 columns gave 1 row and add_subplot failed on the 4th layer).
    # Keep at least one row so the figure never gets zero height.
    nrows = max(1, -(-n_mosaic // ncols))

    fig_w = 15
    # Each grid cell is fig_w / ncols wide and fig_h / nrows tall, so this
    # keeps the cells square.
    fig_h = nrows * fig_w / ncols

    fig = plt.figure(figsize=(fig_w, fig_h))

    for i, (layer_id, layer) in enumerate(layers_to_show):

        mosaic = get_weights_mosaic(model, layer_id=layer_id, n=n)

        ax = fig.add_subplot(nrows, ncols, i + 1)

        im = ax.imshow(mosaic, **kwargs)
        ax.set_title("Layer #{} called '{}' \nof type {}".format(
            layer_id, layer.name, layer.__class__.__name__))

        # Carve a narrow axis to the right of `ax` to hold its colorbar.
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.1)
        fig.colorbar(im, cax=cax)

        ax.axis('off')

        # Grid position of this subplot, derived from its index. This
        # replaces Axes.is_first_row() & co., which were deprecated in
        # Matplotlib 3.4 and removed in 3.6.
        row, col = divmod(i, ncols)

        # Show only the spines on the outer border of the grid.
        # NOTE(review): after ax.axis('off') Matplotlib does not draw the
        # spines, so this may have no visible effect -- confirm intent.
        for sp in ax.spines.values():
            sp.set_visible(False)
        if row == 0:
            ax.spines['top'].set_visible(True)
        if row == nrows - 1:
            ax.spines['bottom'].set_visible(True)
        if col == 0:
            ax.spines['left'].set_visible(True)
        if col == ncols - 1:
            ax.spines['right'].set_visible(True)

    return fig
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号