insepection.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:NMT-RDPG 作者: MultiPath 项目源码 文件源码
def _heatmap_tokens(text):
    """Split one sentence into display tokens.

    ``text`` may be ``bytes`` (decoded as UTF-8) or already-decoded text.
    Every '@@' substring is replaced by '--' before splitting on whitespace.
    """
    # Only decode real bytes: calling .decode on already-decoded text fails
    # (Python 3 str has no .decode; Python 2 unicode would be ASCII-encoded first).
    if isinstance(text, bytes):
        text = text.decode('utf8')
    return [token.strip() for token in text.replace('@@', '--').split()]


def _heatmap_action_grid(action, n_rows, n_cols):
    """Trace an action sequence over an (n_rows, n_cols) grid.

    Starting at cell (0, 0), each action ``a`` advances the column by ``a``
    and the row by ``1 - a``. Every visited cell is set to 1 in ``data`` and
    labelled 'W' (a == 0) or 'C' (otherwise) in ``annote``; cell (0, 0) is
    then marked 1 and labelled 'S'.

    Returns:
        (data, annote): a float array of 0/1 marks and an object array of
        annotation strings, both of shape (n_rows, n_cols).

    Raises:
        IndexError: if the actions walk off the grid.
    """
    data = numpy.zeros((n_rows, n_cols))
    # Object dtype stores native str values, which the '{:s}' annotation
    # formatting requested via fmt='s' accepts on Python 2 and 3; a default
    # numpy.chararray stores bytes on Python 3, which '{:s}' does not accept.
    annote = numpy.full((n_rows, n_cols), '', dtype=object)
    row, col = 0, 0
    for a in action:
        col += a
        row += 1 - a
        data[row, col] = 1
        annote[row, col] = 'W' if a == 0 else 'C'
    data[0, 0] = 1
    annote[0, 0] = 'S'
    return data, annote


def _heatmap_filename(info, idx):
    """Build the output file stem from the ``info`` mapping.

    The stem is 'Idx=<index>||' followed by '.<key>=<value:.2f>' for every
    other key of ``info``, in its iteration order. ``index`` is
    ``info['index']`` when present, otherwise ``idx`` (also used when
    ``info`` is None).
    """
    if info is None:
        info = {}
    index = info['index'] if 'index' in info else idx
    parts = ['Idx={}||'.format(index)]
    for key in info:
        # Compare by value; the old `is not 'index'` relied on string interning.
        if key != 'index':
            parts.append('.{}={:.2f}'.format(key, float(info[key])))
    return ''.join(parts)


def heatmap(sources, refs, trans, actions, idx, atten=None, savefig=True, name='test', info=None, show=False):
    """Plot the action path (and optional attention) for one sentence pair.

    Rows are the tokens of ``sources[idx]`` plus a trailing '||'. Columns are
    '*' plus the tokens of ``trans[idx]`` plus a trailing '||'. Cells visited by
    ``actions[idx]`` are marked 1 and annotated 'S'/'W'/'C' (see
    ``_heatmap_action_grid``).

    Args:
        sources: sequence of source sentences (bytes or text).
        refs: not used by this function.
        trans: sequence of translated sentences (bytes or text).
        actions: sequence of action lists. Each action is expected to be 0 or 1
            (NOTE(review): inferred from the ``1 - a`` step; confirm).
        idx: index of the sentence pair to plot.
        atten: optional sequence of attention matrices. ``atten[idx]``
            transposed is added to ``data[:-1, 1:]``, so it is expected to have
            shape (n_target_tokens, n_source_tokens).
        savefig: if true, write '.images/C_<name>/<stem>.pdf' at 100 dpi.
        name: suffix of the output directory name.
        info: mapping used to build the file stem (see ``_heatmap_filename``).
            When None, ``idx`` is used as the index.
        show: if true, call ``plot.show()`` before closing the figure.
    """
    source = _heatmap_tokens(sources[idx]) + ['||']
    target = ['*'] + _heatmap_tokens(trans[idx]) + ['||']

    data, annote = _heatmap_action_grid(actions[idx], len(source), len(target))
    # `is not None` instead of truthiness: a numpy array of attention
    # matrices has an ambiguous truth value and would raise ValueError.
    if atten is not None:
        attention = numpy.array(atten[idx])
        data[:-1, 1:] += attention.T

    frame = pd.DataFrame(data=data, columns=target, index=source)
    f, ax = plot.subplots(figsize=(11, 11))
    try:
        f.set_canvas(plot.gcf().canvas)
        g = sns.heatmap(frame, ax=ax, annot=annote, fmt='s')
        g.xaxis.tick_top()
        plot.xticks(rotation=90)
        plot.yticks(rotation=0)

        if savefig:
            save_dir = os.path.join('.images', 'C_{}'.format(name))
            if not os.path.exists(save_dir):
                # makedirs also creates '.images' itself when it is missing.
                os.makedirs(save_dir)
            filename = _heatmap_filename(info, idx)
            print('saving...')
            f.savefig(os.path.join(save_dir, filename + '.pdf'), dpi=100)
        if show:
            plot.show()

        print('plotting done.')
    finally:
        # Release the figure even when plotting or saving fails; pyplot keeps
        # every open figure alive until it is closed.
        plot.close(f)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号