model.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:answer-triggering 作者: jiez-osu 项目源码 文件源码
def embed_tokens(self, is_training, config, embedding_initializer):
    """Embeds the question and candidate-sentence token inputs.

    Creates an embedding matrix named "embedding_mat" of shape
    [config.vocab_size, config.word_embed_size] inside variable scope
    "embed", pinned to /cpu:0, and stores it on ``self._embedding``.
    Every time step of ``self._input_question`` and
    ``self._input_sentences`` is then looked up in that matrix.

    When ``is_training`` is true and ``config.w_embed_keep_prob < 1``,
    dropout with that keep probability is applied to each looked-up tensor.
    When the module-level ``NUMERIC_CHECK`` flag is set, each tensor is
    wrapped in ``tf.check_numerics``.

    Args:
      is_training: whether the graph is being built for training; only then
        is dropout added.
      config: object providing ``vocab_size``, ``word_embed_size``,
        ``data_type`` and ``w_embed_keep_prob``.
      embedding_initializer: initializer for the embedding matrix
        (presumably pretrained word vectors -- confirm against the caller).

    Returns:
      A tuple ``(embed_question, embed_sentences)``:
        embed_question: list of ``self.max_question_length`` tensors, one
          per question time step.
        embed_sentences: list of ``self.max_sentence_num`` lists, each
          holding ``self.max_sentence_length`` tensors, one per sentence
          time step.
    """
    vocab_size = config.vocab_size
    size = config.word_embed_size
    max_question_length = self.max_question_length
    max_sentence_length = self.max_sentence_length
    max_sentence_num = self.max_sentence_num
    # Dropout is only added to training graphs, and only when the keep
    # probability would actually drop something.
    use_dropout = is_training and config.w_embed_keep_prob < 1

    with tf.variable_scope("embed"):
      with tf.device("/cpu:0"):
        embedding = tf.get_variable(
            "embedding_mat", [vocab_size, size],
            initializer=embedding_initializer,
            dtype=config.data_type,
            # Frozen: trainable=False keeps the matrix out of the trainable
            # variables collection, so default optimizers leave it as
            # initialized. Set trainable=True to fine-tune it instead.
            trainable=False)

        self._embedding = embedding

        def _embed(token_ids, label):
          """Looks up one time step, then applies optional dropout/check."""
          embedded = tf.nn.embedding_lookup(embedding, token_ids)
          if use_dropout:
            embedded = tf.nn.dropout(embedded, config.w_embed_keep_prob)
          if NUMERIC_CHECK:
            # The label names the exact input slot, so a NaN/Inf failure
            # points at the offending position. Built in one place so the
            # placeholder count always matches the arguments.
            embedded = tf.check_numerics(
                embedded, "{} numeric error".format(label))
          return embedded

        embed_question = [
            _embed(self._input_question[i],
                   "embed_question[{}]".format(i))
            for i in xrange(max_question_length)]

        # Outer index is the sentence, inner index the time step. This is
        # the same op-creation order as a nested for-loop.
        embed_sentences = [
            [_embed(self._input_sentences[i][j],
                    "embed_sentences[{}][{}]".format(i, j))
             for j in xrange(max_sentence_length)]
            for i in xrange(max_sentence_num)]

    return embed_question, embed_sentences


  # RESULTS
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号