model.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:Relation-Network 作者: juung 项目源码 文件源码
def contextLSTM(self, c, l, c_real_len, reuse = False, scope = "ContextLSTM"):
        """Embed every context sentence with an LSTM and append its label.

        Each sentence is run through a word-level LSTM (``sentenceLSTM``). The
        LSTM output at the sentence's last real (unpadded) token is used as the
        sentence embedding, which is then concatenated with the sentence's label.

        Args
            c: context sentences as word indices; reshaped here to
                [-1, s_max_len] (e.g. [batch_size, 20, 12]).
            l: per-sentence labels; reshaped here to [-1, c_max_len]
                (e.g. [batch_size, 20, 20]).
            c_real_len: unpadded length of every sentence; reshaped here to
                [-1] (e.g. [batch_size, 20]).
            reuse: forwarded to ``rnn.BasicLSTMCell`` for variable sharing.
            scope: currently unused by this method; the sentence LSTM always
                runs under the "sentenceLSTM" scope.

        Returns
            tagged_c_objects: Python list of c_max_len tensors, each of shape
                [batch_size, s_hidden + c_max_len] (sentence embedding followed
                by the sentence's label).
        """

        def sentenceLSTM(s,
                         s_real_len,
                         reuse = reuse,
                         scope = "sentenceLSTM"):
            """Embed sentences with an LSTM, reading the output at the last real token.

            Arguments
                s: sentences as word indices,
                    shape = [batch_size*c_max_len, s_max_len] (e.g. [batch_size*20, 12])
                s_real_len: length of each sentence before zero padding, int32,
                    shape = [batch_size*c_max_len]

            Returns
                embedded sentences, shape = [batch_size*c_max_len, s_hidden]
                (e.g. [batch_size*20, 32])
            """
            embedded_sentence_word = tf.nn.embedding_lookup(self.c_word_embed_matrix, s)
            # static_rnn expects a Python list with one tensor per time step.
            s_input = tf.unstack(embedded_sentence_word, num = self.s_max_len, axis = 1)
            lstm_cell = rnn.BasicLSTMCell(self.s_hidden, reuse = reuse)
            outputs, _ = rnn.static_rnn(lstm_cell, s_input, dtype = tf.float32, scope = scope)

            # list of s_max_len x [N, s_hidden] -> [s_max_len, N, s_hidden] -> [N, s_max_len, s_hidden]
            outputs = tf.stack(outputs)
            outputs = tf.transpose(outputs, [1,0,2])
            # After flattening to [N*s_max_len, s_hidden], row i*s_max_len + t holds
            # sentence i at step t, so this selects each sentence's last real step.
            # Clamp at 0: a zero-length sentence would otherwise produce index -1
            # (an error on CPU) or read the previous sentence's last step.
            last_step = tf.maximum(s_real_len - 1, 0)
            index = tf.range(0, self.batch_size * self.c_max_len) * self.s_max_len + last_step
            outputs = tf.gather(tf.reshape(outputs, [-1, self.s_hidden]), index)
            return outputs

        sentences = tf.reshape(c, shape = [-1, self.s_max_len])
        real_lens = tf.reshape(c_real_len, shape= [-1])
        labels = tf.reshape(l, shape = [-1, self.c_max_len])

        s_embedded = sentenceLSTM(sentences, real_lens, reuse = reuse)
        # Each row: sentence embedding (s_hidden wide) followed by its label (c_max_len wide).
        c_embedded = tf.concat([s_embedded, labels], axis=1)
        # The feature width is the LSTM output size plus the label size. Use
        # s_hidden, not c_word_embed; the two only coincide when they are equal.
        c_embedded = tf.reshape(c_embedded, shape = [self.batch_size, self.c_max_len, self.s_hidden + self.c_max_len])
        # Split along the sentence axis into a list of c_max_len tensors.
        tagged_c_objects = tf.unstack(c_embedded, axis=1)
        return tagged_c_objects
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号