pay_attention.py source code

python

Project: almond-nnparser  Author: Stanford-Mobisocial-IoT-Lab
import matplotlib.pyplot as plt
import numpy as np


def show_heatmap(x, y, attention):
    """Plot attention weights as a heatmap.

    Columns are labeled with the tokens in x, rows with the tokens in y.
    attention must be a 2-D array with at least len(y) rows and len(x) columns.
    """
    # Crop any padding: keep one row per token in y and one column per token in x.
    data = attention[:len(y), :len(x)]

    ax = plt.axes()
    heatmap = plt.pcolor(data, cmap=plt.cm.Blues)

    # Offset by 0.5 so each label sits at the center of its cell.
    xticks = np.arange(len(x)) + 0.5
    yticks = np.arange(len(y)) + 0.5
    plt.xticks(xticks, x, rotation='vertical')
    ax.set_yticks(yticks)
    ax.set_yticklabels(y)

    # make it look less like a scatter plot and more like a colored table
    ax.tick_params(axis='both', length=0)
    ax.invert_yaxis()      # first row of data at the top
    ax.xaxis.tick_top()    # column labels above the table

    plt.colorbar(heatmap)

    plt.show()
    #plt.savefig('./attention-out.pdf')
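
The two preparation steps inside show_heatmap (cropping the attention matrix to the real sequence lengths, and centering the tick labels on the cells) can be sketched without any plotting dependency. This is only an illustrative sketch: the helper names crop_attention and tick_centers, the toy tokens, and the padded matrix are all hypothetical and do not come from almond-nnparser. Plain nested lists stand in for the NumPy array the original function slices.

```python
def crop_attention(attention, n_out, n_in):
    """Keep the first n_out rows and n_in columns of a padded 2-D matrix.

    Hypothetical helper mirroring attention[:len(y), :len(x)] with plain lists.
    """
    return [row[:n_in] for row in attention[:n_out]]


def tick_centers(n):
    """Positions that center n labels on unit-width heatmap cells.

    Hypothetical helper mirroring np.arange(n) + 0.5.
    """
    return [i + 0.5 for i in range(n)]


# Toy data (assumed for illustration): 3 tokens in x, 2 tokens in y,
# with the attention matrix padded to 3 rows by 4 columns.
x = ["turn", "on", "lights"]
y = ["action", "device"]
padded = [
    [0.1, 0.7, 0.2, 0.0],
    [0.2, 0.1, 0.7, 0.0],
    [0.0, 0.0, 0.0, 0.0],
]

data = crop_attention(padded, len(y), len(x))
x_positions = tick_centers(len(x))
y_positions = tick_centers(len(y))

print(data)         # one row per token in y, one column per token in x
print(x_positions)  # label positions along the horizontal axis
print(y_positions)  # label positions along the vertical axis
```

After cropping, the matrix has exactly len(y) rows and len(x) columns, which is why the plot labels columns with x and rows with y.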