def get_mention_emb(self, text_emb, text_outputs, mention_starts, mention_ends):
    """Build one fixed-size representation per candidate mention (span).

    The result concatenates, along axis 1:
      * the ``text_outputs`` row at each span's start position,
      * the ``text_outputs`` row at each span's end position,
      * if ``config["use_features"]``: a learned span-width embedding looked
        up in the ``"mention_width_embeddings"`` variable, passed through
        ``tf.nn.dropout``,
      * if ``config["model_heads"]``: an attention-weighted sum of the
        ``text_emb`` rows inside the span, whose attention logits come from a
        width-1 ``util.projection`` of ``text_outputs``.

    Args:
      text_emb: per-word embeddings, indexed along axis 0 by word position;
        summed under the head attention.
      text_outputs: per-word contextual outputs, indexed along axis 0 by
        word position.
      mention_starts: [num_mentions] integer tensor of span start positions.
      mention_ends: [num_mentions] integer tensor of span end positions,
        treated as inclusive (width is ``1 + end - start``).

    Returns:
      A [num_mentions, emb] tensor, where emb is the total width of the
      concatenated pieces.

    Side effects:
      When ``config["model_heads"]`` is set, stores the per-word attention
      logits in ``self.head_scores`` ([num_words, 1]).
    """
    # Boundary representations: contextual output at the first and last token.
    pieces = [
        tf.gather(text_outputs, mention_starts),  # [num_mentions, emb]
        tf.gather(text_outputs, mention_ends),  # [num_mentions, emb]
    ]
    # Ends are inclusive, hence the leading 1.
    widths = 1 + mention_ends - mention_starts  # [num_mentions]

    if self.config["use_features"]:
        # Width w is stored in row w - 1 of the table.
        # NOTE(review): widths above max_mention_width would index past the
        # table; presumably such spans are excluded upstream, so confirm.
        width_ids = widths - 1  # [num_mentions]
        width_table = tf.get_variable(
            "mention_width_embeddings",
            [self.config["max_mention_width"], self.config["feature_size"]])
        width_vecs = tf.gather(width_table, width_ids)  # [num_mentions, feature_size]
        # NOTE(review): self.dropout is passed positionally, which in TF1 is
        # keep_prob; presumably it is a keep probability, so confirm.
        pieces.append(tf.nn.dropout(width_vecs, self.dropout))

    if self.config["model_heads"]:
        window = self.config["max_mention_width"]
        # Every span gets a fixed window of `window` positions starting at its
        # start token. Positions past the last word are clamped to it.
        offsets = tf.expand_dims(tf.range(window), 0)  # [1, window]
        origins = tf.expand_dims(mention_starts, 1)  # [num_mentions, 1]
        positions = offsets + origins  # [num_mentions, window]
        last_word = util.shape(text_outputs, 0) - 1
        positions = tf.minimum(last_word, positions)  # [num_mentions, window]
        window_emb = tf.gather(text_emb, positions)  # [num_mentions, window, emb]

        self.head_scores = util.projection(text_outputs, 1)  # [num_words, 1]
        window_logits = tf.gather(self.head_scores, positions)  # [num_mentions, window, 1]

        # 1.0 inside the span and 0.0 for window slots past its end.
        # log(0) = -inf gives those slots zero softmax weight.
        in_span = tf.sequence_mask(widths, window, dtype=tf.float32)  # [num_mentions, window]
        in_span = tf.expand_dims(in_span, 2)  # [num_mentions, window, 1]
        # NOTE(review): newer TF 1.x releases deprecate `dim` in favour of `axis`.
        attention = tf.nn.softmax(window_logits + tf.log(in_span), dim=1)  # [num_mentions, window, 1]
        pieces.append(tf.reduce_sum(attention * window_emb, 1))  # [num_mentions, emb]

    return tf.concat(pieces, 1)  # [num_mentions, emb]
评论列表
文章目录