def saveAttention(input_sentence, attentions, outpath):
    """Render an attention matrix as a heatmap and save it to ``outpath``.

    Args:
        input_sentence: Optional labels for the rows of the heatmap. A string
            is split into characters; any other iterable is used item by
            item. When falsy, matplotlib's default y ticks are kept.
        attentions: Tensor-like object exposing ``detach()``, ``cpu()`` and
            ``numpy()`` (presumably a 2-D ``torch.Tensor``, since the result
            is handed to ``matshow`` -- confirm against callers).
        outpath: Destination handed to ``savefig``.

    Side effects:
        Switches matplotlib to the ``Agg`` backend and closes *all* open
        pyplot figures before returning, including on failure.
    """
    # Select the non-interactive backend before pyplot is imported so the
    # function also works on machines without a display.
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    # detach() lets tensors that still require grad be converted to numpy.
    weights = attentions.detach().cpu().numpy()

    fig = plt.figure(figsize=(24, 10))
    try:
        ax = fig.add_subplot(111)
        cax = ax.matshow(weights, cmap='bone')
        fig.colorbar(cax)

        if input_sentence:
            # Pin one tick per labelled row. Explicit tick positions keep
            # each label on its own row instead of relying on which ticks an
            # automatic locator happens to emit. Labels beyond the number of
            # rows are dropped so the view limits are not stretched.
            labels = list(input_sentence)[:weights.shape[0]]
            ax.set_yticks(range(len(labels)))
            ax.set_yticklabels(labels)

        fig.tight_layout()
        fig.savefig(outpath)
    finally:
        # Release the figure even when layout or saving fails; 'all' is kept
        # to preserve the original clean-up behaviour.
        plt.close('all')
# Comment list
# Article table of contents