coref_model.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:e2e-coref 作者: kentonl 项目源码 文件源码
def get_mention_emb(self, text_emb, text_outputs, mention_starts, mention_ends):
    """Build one embedding vector per candidate mention span.

    The result concatenates, along axis 1:
      1. the row of `text_outputs` at each mention's start index;
      2. the row of `text_outputs` at each mention's end index;
      3. if self.config["use_features"]: a learned span-width embedding,
         looked up in the "mention_width_embeddings" variable
         ([max_mention_width, feature_size]) and passed through
         tf.nn.dropout with self.dropout;
      4. if self.config["model_heads"]: an attention-weighted sum of the
         `text_emb` rows covered by the span, where the attention logits
         come from util.projection(text_outputs, 1).

    Args:
      text_emb: per-word embeddings that are averaged by the head
        attention (presumably [num_words, emb] -- TODO confirm with caller).
      text_outputs: per-word encoder outputs. Rows are gathered at the
        span endpoints and used to compute head scores (presumably
        [num_words, emb]).
      mention_starts: start word index of each mention (presumably an int
        tensor of shape [num_mentions]).
      mention_ends: end word index of each mention. The width formula
        below (1 + end - start) treats this index as inclusive.

    Returns:
      tf.concat of the pieces above along axis 1. Its second dimension is
      the sum of the widths of whichever pieces were included.

    Side effects:
      When self.config["model_heads"] is true, assigns self.head_scores.
      Calls tf.get_variable("mention_width_embeddings", ...) and
      util.projection, which presumably creates its own variables.
      NOTE(review): keep the op order as-is. In TF1 graph mode, op-level
      random seeds (used by tf.nn.dropout) can depend on how many ops were
      created before them, and variable creation order matters for reuse.
    """
    mention_emb_list = []

    # Endpoint representations: one encoder-output row per span boundary.
    mention_start_emb = tf.gather(text_outputs, mention_starts) # [num_mentions, emb]
    mention_emb_list.append(mention_start_emb)

    mention_end_emb = tf.gather(text_outputs, mention_ends) # [num_mentions, emb]
    mention_emb_list.append(mention_end_emb)

    # Span length in words (ends are inclusive, hence the +1).
    mention_width = 1 + mention_ends - mention_starts # [num_mentions]
    if self.config["use_features"]:
      # Width 1 maps to row 0 of the table. This assumes every width is
      # <= max_mention_width; a larger width would be an out-of-range
      # gather index. TODO confirm that spans are capped upstream.
      mention_width_index = mention_width - 1 # [num_mentions]
      mention_width_emb = tf.gather(tf.get_variable("mention_width_embeddings", [self.config["max_mention_width"], self.config["feature_size"]]), mention_width_index) # [num_mentions, emb]
      # In TF1, the second positional argument of tf.nn.dropout is
      # keep_prob. Presumably self.dropout is a keep probability -- confirm.
      mention_width_emb = tf.nn.dropout(mention_width_emb, self.dropout)
      mention_emb_list.append(mention_width_emb)

    if self.config["model_heads"]:
      # Word positions start, start+1, ..., start+max_mention_width-1 for
      # every mention (a fixed-size window, broadcast-added).
      mention_indices = tf.expand_dims(tf.range(self.config["max_mention_width"]), 0) + tf.expand_dims(mention_starts, 1) # [num_mentions, max_mention_width]
      # Clamp to the last word so the window never indexes past the text.
      # Assuming ends are valid word indices, every clamped position lies at
      # an offset >= the span width, so the mask below zeroes it out.
      mention_indices = tf.minimum(util.shape(text_outputs, 0) - 1, mention_indices) # [num_mentions, max_mention_width]
      mention_text_emb = tf.gather(text_emb, mention_indices) # [num_mentions, max_mention_width, emb]
      # One scalar "headedness" logit per word, also exposed on self.
      self.head_scores = util.projection(text_outputs, 1) # [num_words, 1]
      mention_head_scores = tf.gather(self.head_scores, mention_indices) # [num_mentions, max_mention_width, 1]
      # 1.0 for offsets inside the span (offset < width), 0.0 otherwise.
      mention_mask = tf.expand_dims(tf.sequence_mask(mention_width, self.config["max_mention_width"], dtype=tf.float32), 2) # [num_mentions, max_mention_width, 1]
      # log(1) = 0 leaves in-span logits unchanged; log(0) = -inf gives
      # out-of-span offsets zero softmax weight. Softmax runs over the window
      # axis (dim=1). Width is always >= 1, so at least one offset is valid.
      # NOTE(review): later TF 1.x deprecates `dim` in favor of `axis`.
      mention_attention = tf.nn.softmax(mention_head_scores + tf.log(mention_mask), dim=1) # [num_mentions, max_mention_width, 1]
      # Attention-weighted sum of the word embeddings inside the span.
      mention_head_emb = tf.reduce_sum(mention_attention * mention_text_emb, 1) # [num_mentions, emb]
      mention_emb_list.append(mention_head_emb)

    # Concatenate all included pieces feature-wise.
    mention_emb = tf.concat(mention_emb_list, 1) # [num_mentions, emb]
    return mention_emb
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号