import tensorflow as tf  # TF 1.x API (tf.unsorted_segment_sum, map_fn with dtype)
# static_rank, func_scope, and FLAGS are presumably provided by the
# project's own utility and configuration modules.

def _sum_attentions(attentions, document):
    # attentions: [batch, tokens] weights; document: [batch, tokens] entity ids.
    assert static_rank(attentions) == 2 and static_rank(document) == 2

    num_entities = tf.reduce_max(document) + 1

    @func_scope()
    def _sum_attention(args):
        attentions, document = args
        assert static_rank(attentions) == 1 and static_rank(document) == 1
        # Total attention mass received by each entity id in one example.
        return tf.unsorted_segment_sum(attentions, document, num_entities)

    attentions = tf.map_fn(_sum_attention,
                           [attentions, document],
                           dtype=FLAGS.float_type)

    # Keep only the columns that correspond to candidate entity ids.
    return attentions[:, FLAGS.first_entity_index:FLAGS.last_entity_index + 1]
attention_sum_reader.py source code (Python)
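For context, here is a minimal sketch (not part of the original file) of the core operation the function relies on: tf.unsorted_segment_sum adds up the attention weights of all tokens that share the same entity id. It assumes TensorFlow 1.x with a session, and the shapes and values are invented for illustration.

    import tensorflow as tf

    attentions = tf.constant([0.1, 0.2, 0.3, 0.3, 0.1])  # token-level attention
    entity_ids = tf.constant([0, 1, 2, 1, 3])             # entity id per token
    per_entity = tf.unsorted_segment_sum(attentions, entity_ids, 4)

    with tf.Session() as sess:
        print(sess.run(per_entity))  # [0.1, 0.5, 0.3, 0.1]; entity 1 gets 0.2 + 0.3

In _sum_attentions this per-example reduction is applied across the batch with tf.map_fn, and the result is sliced so that only the columns for candidate entities remain.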