def plot_attention(attention_matrix: np.ndarray, source_tokens: List[str], target_tokens: List[str], filename: str):
    """
    Uses matplotlib for creating a visualization of the attention matrix.

    Each row of ``attention_matrix`` corresponds to one target token. The matrix is
    transposed for display, so target tokens run along the x-axis and source tokens
    along the y-axis (darker cells = larger values, "Greys" colormap).

    :param attention_matrix: The attention matrix, with one row per target token.
    :param source_tokens: A list of source tokens (y-axis tick labels).
    :param target_tokens: A list of target tokens (x-axis tick labels).
    :param filename: The file to which the attention visualization will be written.
    :raises AssertionError: If the number of rows of ``attention_matrix`` differs from
        ``len(target_tokens)``.
    """
    import matplotlib
    # Non-interactive backend: the plot is only written to a file, never shown.
    # NOTE(review): matplotlib.use switches the backend process-wide.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    assert attention_matrix.shape[0] == len(target_tokens), \
        "attention matrix has %d rows but %d target tokens were given" % (
            attention_matrix.shape[0], len(target_tokens))

    # Draw on a dedicated figure rather than pyplot's implicit "current" figure, so that
    # repeated calls neither paint over each other nor accumulate open figures.
    fig, ax = plt.subplots()
    try:
        ax.imshow(attention_matrix.transpose(), interpolation="nearest", cmap="Greys")
        ax.set_xlabel("target")
        ax.set_ylabel("source")
        ax.set_xticks(list(range(len(target_tokens))))
        ax.set_yticks(list(range(len(source_tokens))))
        ax.set_xticklabels(target_tokens, rotation='vertical')
        ax.set_yticklabels(source_tokens)
        fig.tight_layout()
        fig.savefig(filename)
    finally:
        # Release the figure even if rendering or saving fails.
        plt.close(fig)
    # Lazy %-formatting also works when filename is a path-like object, not just a str.
    logger.info("Saved alignment visualization to %s", filename)
评论列表
文章目录