def printNetwork_weights(prototxt_filename, caffemodel_filename, image_format='png'):
    '''
    For each layer that has parameters, save a weight heatmap and a weight histogram.

    Two images are written per layer, using the caffemodel path as prefix:
      - ``<caffemodel_filename>_weights_<layer>.<image_format>`` (heatmap)
      - ``<caffemodel_filename>_weights_hist_<layer>.<image_format>`` (histogram)

    Args:
        prototxt_filename: path to the network definition (.prototxt).
        caffemodel_filename: path to the trained weights (.caffemodel); also
            used as the prefix of every output image path.
        image_format: matplotlib output format and file extension. Defaults to
            'png'; use 'svg' or 'pdf' for vector images.
    '''
    net = caffe.Net(prototxt_filename, caffemodel_filename, caffe.TEST)
    for layerName in net.params:
        # Blob 0 of a layer's params holds its weights (only that blob is plotted).
        arr = net.params[layerName][0].data
        if arr.size == 0:
            # Nothing to plot; reshape(-1) below would also be ambiguous.
            continue

        # Layer names such as 'conv1/7x7_s2' contain path separators; replace
        # them so each layer maps to a single file name.
        safe_name = layerName.replace('/', '_').replace('\\', '_')

        # matshow only accepts 2-D data, but Caffe convolution weights are 4-D
        # (num_output, channels, kernel_h, kernel_w). Flatten every axis after
        # the first so each row is one output unit; lower-rank weights become
        # a single row.
        if arr.ndim > 1:
            matrix = arr.reshape(arr.shape[0], -1)
        else:
            matrix = arr.reshape(1, -1)

        # Weight heatmap.
        fig = plt.figure(figsize=(10, 10))
        try:
            ax = fig.add_subplot(111)
            cax = ax.matshow(matrix, interpolation='none')
            fig.colorbar(cax, orientation="horizontal")
            fig.savefig(
                '{0}_weights_{1}.{2}'.format(caffemodel_filename, safe_name, image_format),
                dpi=100, format=image_format, bbox_inches='tight')
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)

        # Weight histogram. ravel() gives a single histogram over all weights.
        # A nested list would be treated by hist() as one dataset per row.
        fig = plt.figure()
        try:
            ax = fig.add_subplot(111)
            ax.hist(arr.ravel(), bins=20)
            fig.savefig(
                '{0}_weights_hist_{1}.{2}'.format(caffemodel_filename, safe_name, image_format),
                dpi=100, format=image_format, bbox_inches='tight')
        finally:
            plt.close(fig)
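# 评论列表 ("Comment list") — stray navigation text from the web page this code was copied from
# 文章目录 ("Table of contents") — stray navigation text from the web page this code was copied from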