models.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:deeppavlov 作者: deepmipt 项目源码 文件源码
def encode_sentences(self, text_emb, text_len, text_len_mask):
        """
        Encodes a batch of sentences with a bidirectional LSTM.

        Forward and backward passes use two separate ``utils.CustomLSTMCell``
        instances, each built with ``self.opt["lstm_size"]`` units and
        ``self.dropout``. The backward pass works in three steps:
          1. Reverse the first ``text_len[i]`` steps of every sentence
             (``tf.reverse_sequence``).
          2. Run an ordinary ``tf.nn.dynamic_rnn`` over the reversed input.
          3. Reverse the outputs back.
        This way the backward pass starts at each sentence's last real token
        rather than at padding.

        Args:
            text_emb: rank-3 tensor [num_sentences, max_sentence_length, emb]
                holding the token embeddings.
            text_len: tf.int32 vector [num_sentences] with the true length of
                each sentence. It is used as ``sequence_length`` for both
                ``dynamic_rnn`` calls and as ``seq_lengths`` for both
                ``tf.reverse_sequence`` calls.
            text_len_mask: boolean mask for text_emb. It is forwarded unchanged
                to ``self.flatten_emb_by_sentence``. Presumably it has shape
                [num_sentences, max_sentence_length] — TODO confirm.

        Returns:
            Forward and backward LSTM outputs, concatenated on the last axis.
            Before masking this is [num_sentences, max_sentence_length,
            fw_out + bw_out]. The result is then passed through
            ``self.flatten_emb_by_sentence(..., text_len_mask)``. The final
            shape depends on that helper: presumably
            [num_valid_tokens, 2 * lstm_size] — TODO confirm.
        """
        num_sentences = tf.shape(text_emb)[0]
        lstm_size = self.opt["lstm_size"]

        def _tiled_initial_state(cell):
            # Repeat the cell's initial (c, h) rows num_sentences times along
            # axis 0 so every sentence in the batch gets its own copy.
            # Presumably cell.initial_state is a single-row state — TODO confirm.
            init = cell.initial_state
            return tf.contrib.rnn.LSTMStateTuple(tf.tile(init.c, [num_sentences, 1]),
                                                 tf.tile(init.h, [num_sentences, 1]))

        # Switch to time-major layout [max_sentence_length, num_sentences, emb].
        # Both dynamic_rnn calls below use time_major=True, and the result is
        # transposed back at the end.
        inputs = tf.transpose(text_emb, [1, 0, 2])

        with tf.variable_scope("fw_cell"):
            cell_fw = utils.CustomLSTMCell(lstm_size, num_sentences, self.dropout)
            preprocessed_inputs_fw = cell_fw.preprocess_input(inputs)
        with tf.variable_scope("bw_cell"):
            cell_bw = utils.CustomLSTMCell(lstm_size, num_sentences, self.dropout)
            preprocessed_inputs_bw = cell_bw.preprocess_input(inputs)
            # Reverse only the valid prefix of each sentence (time axis 0,
            # batch axis 1), so padding stays at the end after reversal.
            preprocessed_inputs_bw = tf.reverse_sequence(preprocessed_inputs_bw,
                                                         seq_lengths=text_len,
                                                         seq_dim=0,
                                                         batch_dim=1)
        state_fw = _tiled_initial_state(cell_fw)
        state_bw = _tiled_initial_state(cell_bw)

        with tf.variable_scope("lstm"):
            with tf.variable_scope("fw_lstm"):
                # Final states are not needed; only per-step outputs are used.
                fw_outputs, _ = tf.nn.dynamic_rnn(cell=cell_fw,
                                                  inputs=preprocessed_inputs_fw,
                                                  sequence_length=text_len,
                                                  initial_state=state_fw,
                                                  time_major=True)
            with tf.variable_scope("bw_lstm"):
                bw_outputs, _ = tf.nn.dynamic_rnn(cell=cell_bw,
                                                  inputs=preprocessed_inputs_bw,
                                                  sequence_length=text_len,
                                                  initial_state=state_bw,
                                                  time_major=True)

        # Undo the reversal so each backward output lines up with the forward
        # output for the same token.
        bw_outputs = tf.reverse_sequence(bw_outputs,
                                         seq_lengths=text_len,
                                         seq_dim=0,
                                         batch_dim=1)

        text_outputs = tf.concat([fw_outputs, bw_outputs], 2)
        # Back to batch-major [num_sentences, max_sentence_length, fw_out + bw_out].
        text_outputs = tf.transpose(text_outputs, [1, 0, 2])
        return self.flatten_emb_by_sentence(text_outputs, text_len_mask)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号