attention_sum_reader.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:tensorflow-extenteten 作者: raviqqe 项目源码 文件源码
def _sum_attentions(attentions, document):
    assert static_rank(attentions) == 2 and static_rank(document) == 2

    num_entities = tf.reduce_max(document) + 1

    @func_scope()
    def _sum_attention(args):
        attentions, document = args
        assert static_rank(attentions) == 1 and static_rank(document) == 1
        return tf.unsorted_segment_sum(attentions, document, num_entities)

    attentions = tf.map_fn(_sum_attention,
                           [attentions, document],
                           dtype=FLAGS.float_type)

    return attentions[:, FLAGS.first_entity_index:FLAGS.last_entity_index + 1]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号