def _heatmap_tokens(sentence):
    """Return the whitespace-separated tokens of ``sentence``.

    ``bytes`` input is decoded as UTF-8 first; text input is used as-is.
    Every ``'@@'`` is replaced by ``'--'`` before splitting (presumably
    subword-joiner markers — unconfirmed).
    """
    if isinstance(sentence, bytes):
        sentence = sentence.decode('utf8')
    # str.split() with no argument already drops surrounding whitespace.
    return sentence.replace('@@', '--').split()


def _heatmap_track(acts, data, annote):
    """Mark the cells visited by walking ``acts`` from cell (0, 0).

    Each action ``a`` advances the column by ``a`` and the row by ``1 - a``.
    Assuming 0/1 actions (not checked): ``0`` moves one row down and is
    annotated ``'W'``; any other value is annotated ``'C'``. Visited cells get
    value 1. ``data`` and ``annote`` are modified in place and also returned.
    """
    row, col = 0, 0
    for a in acts:
        col += a
        row += 1 - a
        data[row, col] = 1
        annote[row, col] = 'W' if a == 0 else 'C'
    return data, annote


def _heatmap_filename(info):
    """Build the figure file name (without extension) from ``info``.

    The name starts with ``Idx=<info['index']>||`` (``None`` when ``info`` is
    missing or has no ``'index'``). It is followed by ``.<key>=<value:.2f>``
    for every other entry, in the dict's iteration order.
    """
    info = info or {}
    parts = ['Idx={}||'.format(info.get('index'))]
    parts.extend('.{}={:.2f}'.format(key, float(value))
                 for key, value in info.items() if key != 'index')
    return ''.join(parts)


def heatmap(sources, refs, trans, actions, idx, atten=None, savefig=True,
            name='test', info=None, show=False):
    """Plot the action path of example ``idx`` as an annotated heatmap.

    Rows are the tokens of ``sources[idx]`` followed by ``'||'``. Columns are
    ``'*'``, the tokens of ``trans[idx]``, then ``'||'``. Cells visited by
    ``actions[idx]`` are set to 1 and annotated ``'W'``/``'C'``. The start cell
    (0, 0) is set to 1 and annotated ``'S'``.

    Args:
        sources: sequence of source sentences (bytes are decoded as UTF-8).
        refs: unused; kept for call-site compatibility.
        trans: sequence of translated sentences (bytes are decoded as UTF-8).
        actions: sequence of action sequences; ``actions[idx]`` is walked
            cell by cell (0/1 values assumed).
        idx: index of the example to plot.
        atten: optional sequence of attention matrices. The transpose of
            ``atten[idx]`` is added to every cell except the last row and the
            first column, so it must broadcast to
            ``(len(source) - 1, len(target) - 1)`` — TODO confirm expected
            orientation with the producer.
        savefig: if true, save a PDF under ``.images/C_<name>/``.
        name: suffix of the output sub-directory.
        info: optional dict used to build the file name. Its ``'index'``
            entry comes first; every other value must be ``float``-convertible.
        show: if true, call ``plot.show()`` before closing the figure.
    """
    source = _heatmap_tokens(sources[idx]) + ['||']
    target = ['*'] + _heatmap_tokens(trans[idx]) + ['||']

    data = numpy.zeros((len(source), len(target)))
    # Object array of native str: numpy.chararray holds bytes on Python 3,
    # which seaborn's fmt='s' annotation formatting cannot handle.
    annote = numpy.full(data.shape, '', dtype=object)
    _heatmap_track(actions[idx], data, annote)
    data[0, 0] = 1
    annote[0, 0] = 'S'
    # Explicit emptiness test: `if atten:` is ambiguous for numpy arrays.
    if atten is not None and len(atten) > 0:
        data[:-1, 1:] += numpy.asarray(atten[idx]).T

    frame = pd.DataFrame(data=data, columns=target, index=source)
    fig, ax = plot.subplots(figsize=(11, 11))
    g = sns.heatmap(frame, ax=ax, annot=annote, fmt='s')
    g.xaxis.tick_top()
    plot.xticks(rotation=90)
    plot.yticks(rotation=0)

    if savefig:
        folder = os.path.join('.images', 'C_{}'.format(name))
        # makedirs also creates a missing '.images' parent (os.mkdir would fail).
        if not os.path.isdir(folder):
            os.makedirs(folder)
        print('saving...')
        fig.savefig(os.path.join(folder, _heatmap_filename(info) + '.pdf'),
                    dpi=100)
    if show:
        plot.show()
    print('plotting done.')
    plot.close(fig)
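# 评论列表 (comment list) — web-page navigation residue, not code
# 文章目录 (table of contents) — web-page navigation residue, not code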