def _create_figure(predictions_dict):
    """Create and return a new figure visualizing the attention scores of a
    single model prediction.

    The attention matrix is drawn as a heat map with the source tokens along
    the x-axis and the predicted (target) tokens along the y-axis.

    Args:
        predictions_dict: Mapping that must provide the keys
            ``"predicted_tokens"``, ``"features.source_len"``,
            ``"features.source_tokens"`` and ``"attention_scores"``.
            ``"attention_scores"`` is sliced as
            ``[:prediction_len, :source_len]``, so its first axis indexes
            target positions and its second axis source positions.

    Returns:
        The newly created matplotlib figure. It is created with
        ``plt.figure`` and therefore registered with pyplot; callers should
        ``plt.close(fig)`` once done with it to release its resources.
    """
    # Number of predicted rows to plot. Computed by a helper defined
    # elsewhere -- presumably truncating at an end-of-sequence token; confirm.
    prediction_len = _get_prediction_length(predictions_dict)
    # Exactly one label per plotted row: recent matplotlib raises ValueError
    # when the number of tick labels differs from the number of tick
    # locations (older versions silently ignored the surplus labels).
    target_words = list(predictions_dict["predicted_tokens"])[:prediction_len]

    # Trim the source side to the reported source length.
    source_len = predictions_dict["features.source_len"]
    source_words = predictions_dict["features.source_tokens"][:source_len]
    scores = predictions_dict["attention_scores"][:prediction_len, :source_len]

    fig = plt.figure(figsize=(8, 8))
    # Draw on an explicit Axes of the new figure instead of relying on
    # pyplot's implicit "current axes" state.
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(scores, interpolation="nearest", cmap=plt.cm.Blues)
    ax.set_xticks(np.arange(source_len))
    ax.set_xticklabels(source_words, rotation=45)
    ax.set_yticks(np.arange(prediction_len))
    ax.set_yticklabels(target_words, rotation=-45)
    fig.tight_layout()
    return fig
# Web-page residue ("Comment list", "Table of contents") -- not part of the code.