def show_heatmap(x, y, attention, save_path=None, show=True):
    """Plot an attention matrix as a heatmap with labels on both axes.

    Rows of the plotted matrix correspond to the entries of ``y`` and
    columns to the entries of ``x``. Only the top-left ``len(y) x len(x)``
    corner of ``attention`` is drawn, so a larger matrix may be passed.

    Args:
        x: Sequence of column labels, drawn along the top edge (rotated
            vertically).
        y: Sequence of row labels, drawn down the left edge with the first
            label at the top.
        attention: 2-D array-like of weights indexed as
            ``attention[row, col]``. Presumably rows align with the target
            side and columns with the source side; TODO confirm against
            callers.
        save_path: Optional file path. When given, the figure is written
            there with ``Figure.savefig`` before it is shown.
        show: When True (the default, matching the original behavior),
            display the figure with ``plt.show()``.

    Returns:
        The ``(fig, ax)`` pair that was drawn on, so callers can tweak or
        save the plot themselves (e.g. with ``show=False``).
    """
    # np.asarray lets callers pass nested lists as well as ndarrays; the
    # 2-D slice below would fail on a plain Python list.
    data = np.asarray(attention)[:len(y), :len(x)]

    # Use a fresh figure so we never draw on top of (or add a second
    # colorbar to) whatever figure happens to be current.
    fig, ax = plt.subplots()
    heatmap = ax.pcolor(data, cmap=plt.cm.Blues)

    # pcolor cell i spans [i, i + 1); offset ticks by 0.5 to centre labels.
    ax.set_xticks(np.arange(len(x)) + 0.5)
    ax.set_xticklabels(x, rotation='vertical')
    ax.set_yticks(np.arange(len(y)) + 0.5)
    ax.set_yticklabels(y)

    # Make it look less like a scatter plot and more like a colored table:
    # hide tick marks, put row 0 at the top and column labels above the grid.
    ax.tick_params(axis='both', length=0)
    ax.invert_yaxis()
    ax.xaxis.tick_top()

    fig.colorbar(heatmap, ax=ax)

    # Save before showing: a blocking plt.show() may close the figure.
    if save_path is not None:
        fig.savefig(save_path)
    if show:
        plt.show()
    return fig, ax
pay_attention.py 文件源码
python
阅读 30
收藏 0
点赞 0
评论 0
评论列表
文章目录