keras_mnist_vis.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:keras-mnist-workshop 作者: drschilling 项目源码 文件源码
def keras_digits_vis(model, X_test, y_test):

    layer_idx = utils.find_layer_idx(model, 'preds')
    model.layers[layer_idx].activation = activations.linear
    model = utils.apply_modifications(model)

    for class_idx in np.arange(10):    
        indices = np.where(y_test[:, class_idx] == 1.)[0]
        idx = indices[0]

        f, ax = plt.subplots(1, 4)
        ax[0].imshow(X_test[idx][..., 0])

        for i, modifier in enumerate([None, 'guided', 'relu']):
            heatmap = visualize_saliency(model, layer_idx, filter_indices=class_idx, 
                                        seed_input=X_test[idx], backprop_modifier=modifier)
            if modifier is None:
                modifier = 'vanilla'
            ax[i+1].set_title(modifier)    
            ax[i+1].imshow(heatmap)
    plt.imshow(heatmap)
    plt.show()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号