def plot_network(image, model, label=None):
    """Plot per-layer heatmaps for every Conv2D layer of *model* on one image.

    Builds a 2-row figure with one column per ``Conv2D`` layer in
    ``model.layers``. The top row holds
    ``plot_heatmap(image, model, <layer>, "abnormal", ax)`` for each layer; the
    bottom row holds the same call with ``"normal"`` and the ``inferno``
    colormap. The figure title shows the model's raw prediction for *image*
    together with *label*, and the figure is toggled to full screen and shown.

    Args:
        image: A single input sample without a batch axis; a leading batch
            axis of size 1 is added before calling ``model.predict``.
        model: Presumably a Keras model (exposes ``layers`` and ``predict``;
            layers are filtered with ``isinstance(..., Conv2D)``).
        label: Optional label printed next to the prediction in the title.

    Raises:
        ValueError: If *model* contains no ``Conv2D`` layers (nothing to plot).
    """
    layer_names = [
        layer.name for layer in model.layers if isinstance(layer, Conv2D)
    ]
    n_conv = len(layer_names)
    if n_conv == 0:
        raise ValueError("model has no Conv2D layers to visualize")

    prediction = model.predict(np.expand_dims(image, 0))

    # squeeze=False keeps a 2-D (2, n_conv) axes array even when n_conv == 1;
    # without it each row would be a bare Axes and indexing it would fail.
    fig, (abnormal_axes, normal_axes) = plt.subplots(2, n_conv, squeeze=False)

    # Toggle full screen on *this* figure. Fetching pyplot's current manager
    # before subplots() targeted a different (possibly new, empty) figure.
    fig.canvas.manager.full_screen_toggle()

    for name, ax in zip(layer_names, abnormal_axes):
        plot_heatmap(image, model, name, "abnormal", ax)
    for name, ax in zip(layer_names, normal_axes):
        plot_heatmap(image, model, name, "normal", ax, cmap=plt.cm.inferno)

    fig.suptitle("Prediction: {}, {}".format(prediction, label))
    fig.show()
visualizations.py 文件源码
python
阅读 24
收藏 0
点赞 0
评论 0
评论列表
文章目录