coref_model.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:e2e-coref 作者: kentonl 项目源码 文件源码
def tensorize_example(self, example, is_training, oov_counts=None):
  """Convert one document dict into the numpy arrays fed to the model.

  Args:
    example: mapping with keys "clusters", "sentences", "speakers" and
      "doc_key".
      - "sentences" is a list of sentences, each a list of word strings.
      - "speakers" is a nested list that must flatten to exactly one entry
        per word.
      - "clusters" is a list of clusters, each a list of mentions. Every
        mention is converted with tuple() (presumably [start, end] word
        indices -- confirm against tensorize_mentions).
    is_training: echoed back in the returned tuple. When true and the
      document has more than config["max_training_sentences"] sentences,
      the result is produced by self.truncate_example instead.
    oov_counts: optional mutable sequence with one counter per embedding
      dictionary. Entry k is incremented for every word (after optional
      lowercasing) that is missing from the k-th dictionary.

  Returns:
    (word_emb, char_index, text_len, speaker_ids, genre, is_training,
     gold_starts, gold_ends, cluster_ids), or whatever
    self.truncate_example returns for those same arguments.

  Raises:
    ValueError: if the number of speakers differs from the number of words.
  """
  clusters = example["clusters"]

  # All gold mentions in sorted order, plus the cluster index of each one.
  gold_mentions = sorted(tuple(m) for m in util.flatten(clusters))
  gold_mention_map = {m: i for i, m in enumerate(gold_mentions)}
  cluster_ids = np.zeros(len(gold_mentions))
  for cluster_id, cluster in enumerate(clusters):
    for mention in cluster:
      cluster_ids[gold_mention_map[tuple(mention)]] = cluster_id

  sentences = example["sentences"]
  num_words = sum(len(s) for s in sentences)
  speakers = util.flatten(example["speakers"])

  # Explicit check rather than `assert`, which is removed under `python -O`.
  if num_words != len(speakers):
    raise ValueError(
        "Expected one speaker per word, got {} words and {} speakers".format(
            num_words, len(speakers)))

  max_sentence_length = max(len(s) for s in sentences)
  # Character rows are padded to at least max(filter_widths) (presumably the
  # character-CNN filter widths -- confirm) so the widest filter always fits.
  max_word_length = max(max(max(len(w) for w in s) for s in sentences),
                        max(self.config["filter_widths"]))
  word_emb = np.zeros([len(sentences), max_sentence_length, self.embedding_size])
  char_index = np.zeros([len(sentences), max_sentence_length, max_word_length])
  text_len = np.array([len(s) for s in sentences])
  for i, sentence in enumerate(sentences):
    for j, word in enumerate(sentence):
      lowered = word.lower()  # computed once, shared by every lowercasing dict
      current_dim = 0
      # Each embedding dictionary fills its own contiguous slice of size `size`.
      for k, (d, (size, lowercase)) in enumerate(zip(self.embedding_dicts, self.embedding_info)):
        current_word = lowered if lowercase else word
        if oov_counts is not None and current_word not in d:
          oov_counts[k] += 1
        # d is indexed even for OOV words, so it presumably returns a fallback
        # vector for unknown keys (e.g. a defaultdict) -- confirm.
        word_emb[i, j, current_dim:current_dim + size] = util.normalize(d[current_word])
        current_dim += size
      char_index[i, j, :len(word)] = [self.char_dict[c] for c in word]

  # Assign speaker ids in order of first appearance. The previous
  # enumerate(set(speakers)) depended on string hash randomization, so the
  # ids changed between processes.
  speaker_dict = {}
  for speaker in speakers:
    if speaker not in speaker_dict:
      speaker_dict[speaker] = len(speaker_dict)
  speaker_ids = np.array([speaker_dict[s] for s in speakers])

  doc_key = example["doc_key"]
  genre = self.genres[doc_key[:2]]  # genre is looked up by the first two chars of doc_key

  gold_starts, gold_ends = self.tensorize_mentions(gold_mentions)

  if is_training and len(sentences) > self.config["max_training_sentences"]:
    return self.truncate_example(word_emb, char_index, text_len, speaker_ids, genre, is_training, gold_starts, gold_ends, cluster_ids)
  return word_emb, char_index, text_len, speaker_ids, genre, is_training, gold_starts, gold_ends, cluster_ids
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号