utils.py source code

python

Project: sockeye  Author: awslabs
import logging
from typing import List

import numpy as np

logger = logging.getLogger(__name__)


def plot_attention(attention_matrix: np.ndarray, source_tokens: List[str], target_tokens: List[str], filename: str):
    """
    Uses matplotlib for creating a visualization of the attention matrix.

    :param attention_matrix: The attention matrix.
    :param source_tokens: A list of source tokens.
    :param target_tokens: A list of target tokens.
    :param filename: The file to which the attention visualization will be written.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    assert attention_matrix.shape[0] == len(target_tokens)

    plt.imshow(attention_matrix.transpose(), interpolation="nearest", cmap="Greys")
    plt.xlabel("target")
    plt.ylabel("source")
    plt.gca().set_xticks(list(range(len(target_tokens))))
    plt.gca().set_yticks(list(range(len(source_tokens))))
    plt.gca().set_xticklabels(target_tokens, rotation='vertical')
    plt.gca().set_yticklabels(source_tokens)
    plt.tight_layout()
    plt.savefig(filename)
    # Close the figure so that repeated calls do not draw on top of an earlier plot.
    plt.close()
    logger.info("Saved alignment visualization to %s", filename)
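
The function expects attention_matrix to have one row per target token and transposes it before drawing, so source tokens end up on the y axis and target tokens on the x axis. The following is a minimal sketch of preparing such an input. The helper make_toy_attention, the token lists, and the output file name are hypothetical and not part of sockeye; the random weights only stand in for real model output.

```python
from typing import List

import numpy as np


def make_toy_attention(source_tokens: List[str], target_tokens: List[str], seed: int = 0) -> np.ndarray:
    """Hypothetical helper: random attention weights of shape (target_len, source_len)."""
    rng = np.random.default_rng(seed)
    scores = rng.normal(size=(len(target_tokens), len(source_tokens)))
    # Softmax over the source axis so that each target row sums to 1.
    scores -= scores.max(axis=1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=1, keepdims=True)


source_tokens = ["das", "ist", "gut", "</s>"]
target_tokens = ["this", "is", "good"]
attention_matrix = make_toy_attention(source_tokens, target_tokens)

# Rows index target tokens, which is what the assertion in plot_attention checks.
assert attention_matrix.shape[0] == len(target_tokens)

# plot_attention draws the transpose: one image row per source token (y axis)
# and one image column per target token (x axis).
image = attention_matrix.transpose()
print(image.shape)  # (4, 3)

# With plot_attention defined as above, the figure could then be written with:
# plot_attention(attention_matrix, source_tokens, target_tokens, "attention.png")
```

The call to plot_attention is left commented out so that the sketch runs without matplotlib installed.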