def vis_activation(self, tweet, grads=False, activations=True, over_words=True, over_units=False):
    """Display a heatmap of the activation/gradient matrix for one tweet.

    The matrix and its axis labels come from ``self._get_activations_gradients``.
    The first value of the ``'score'`` column returned by ``self.predict`` is
    shown in the plot title.

    Args:
        tweet: Input passed unchanged to ``self.predict`` and to
            ``self._get_activations_gradients``.
        grads, activations, over_words, over_units: Flags forwarded
            positionally to ``self._get_activations_gradients``. Judging by
            the names, they presumably choose gradients and/or activations and
            the aggregation direction (words vs. units); confirm against that
            helper.

    Returns:
        None. The figure is rendered with ``plt.show()``.
    """
    prediction = self.predict(tweet)
    matrix, row_labels, col_labels = self._get_activations_gradients(
        tweet, grads, activations, over_words, over_units)

    plt.figure(figsize=(14, 4))

    # Seven discrete colours from a diverging palette between HUSL hues
    # 220 and 20.
    # NOTE(review): sns.heatmap is not given ``center=``, so the palette's
    # neutral midpoint maps to the middle of the data range rather than
    # to 0. Confirm this is intended if the values are signed.
    palette = sns.diverging_palette(220, 20, n=7)

    heat_ax = sns.heatmap(
        matrix,
        xticklabels=col_labels,  # text labels from the helper -> columns
        yticklabels=row_labels,  # layer labels from the helper -> rows
        cmap=palette,
    )

    # Move the column labels above the grid and stand them vertically;
    # keep the row labels horizontal.
    heat_ax.xaxis.tick_top()
    plt.xticks(rotation=90)
    plt.yticks(rotation=0)

    score = prediction['score'].values[0]
    plt.title('Score:%s' % score)
    plt.show()
attention_vis.py 文件源码
python
阅读 30
收藏 0
点赞 0
评论 0
评论列表
文章目录