dump_attention.py source file

python

Project: seq2seq  Author: google
import matplotlib.pyplot as plt
import numpy as np


def _create_figure(predictions_dict):
  """Creates and returns a new figure that visualizes the
  attention scores for a single model prediction.
  """

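  # Find out how long the predicted sequence is; _get_prediction_length
  # is a helper defined elsewhere in this module
  prediction_len = _get_prediction_length(predictions_dict)

  # Keep only as many target words as there are rows in the plot, so the
  # y-axis labels match the tick positions set below
  target_words = list(predictions_dict["predicted_tokens"])[:prediction_len]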

  # Get source words
  source_len = predictions_dict["features.source_len"]
  source_words = predictions_dict["features.source_tokens"][:source_len]

  # Plot
  fig = plt.figure(figsize=(8, 8))
  plt.imshow(
      X=predictions_dict["attention_scores"][:prediction_len, :source_len],
      interpolation="nearest",
      cmap=plt.cm.Blues)
  plt.xticks(np.arange(source_len), source_words, rotation=45)
  plt.yticks(np.arange(prediction_len), target_words, rotation=-45)
  fig.tight_layout()

  return fig