model.py source code

python

Project: Relation-Network    Author: juung
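import tensorflow as tf               # TensorFlow 1.x
from tensorflow.contrib import rnn    # BasicLSTMCell, static_rnn (TF 1.x only)

def questionLSTM(self, q, q_real_len, reuse=False, scope="questionLSTM"):
    """
    Args
        q: zero-padded questions, shape = [batch_size, q_max_len]
        q_real_len: original (unpadded) question length, shape = [batch_size, 1]

    Returns
        embedded_q: embedded questions, shape = [batch_size, q_hidden(32)]
    """
    # Word embeddings: [batch_size, q_max_len, embed_dim]
    embedded_q_word = tf.nn.embedding_lookup(self.q_word_embed_matrix, q)
    # static_rnn takes a Python list of q_max_len tensors, each [batch_size, embed_dim]
    q_input = tf.unstack(embedded_q_word, num=self.q_max_len, axis=1)
    lstm_cell = rnn.BasicLSTMCell(self.q_hidden, reuse=reuse)
    outputs, _ = rnn.static_rnn(lstm_cell, q_input, dtype=tf.float32, scope=scope)

    # q_max_len x [batch_size, q_hidden] -> [batch_size, q_max_len, q_hidden]
    outputs = tf.stack(outputs)
    outputs = tf.transpose(outputs, [1, 0, 2])
    # Flatten q_real_len to [batch_size]; with shape [batch_size, 1] the sum below
    # would broadcast to [batch_size, batch_size] instead of one index per question
    q_real_len = tf.reshape(q_real_len, [-1])
    # Row of the last real (non-padding) time step of each question in the flattened outputs
    index = tf.range(0, self.batch_size) * self.q_max_len + (q_real_len - 1)
    # Flatten with q_hidden (the LSTM output size), then pick one row per question
    outputs = tf.gather(tf.reshape(outputs, [-1, self.q_hidden]), index)
    return outputs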
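The last lines of questionLSTM flatten the padded LSTM outputs to [batch_size * q_max_len, q_hidden] and gather one row per question, at position b * q_max_len + (length - 1). That index arithmetic can be checked without TensorFlow. Below is a minimal pure-Python sketch, assuming nested lists stand in for tensors; `last_valid_outputs` is a hypothetical helper name, not part of the Relation-Network project.

```python
def last_valid_outputs(outputs, real_len):
    """Pick the output at the last real time step of each sequence.

    outputs:  nested list shaped [batch_size][max_len][hidden] (stand-in for a tensor)
    real_len: unpadded length of each sequence, one int per sequence (1 <= len <= max_len)
    """
    batch_size = len(outputs)
    max_len = len(outputs[0])
    # Flatten [batch_size, max_len, hidden] -> [batch_size * max_len, hidden],
    # mirroring tf.reshape(outputs, [-1, q_hidden]) in questionLSTM
    flat = [step for seq in outputs for step in seq]
    # Same arithmetic as tf.range(0, batch_size) * max_len + (real_len - 1)
    index = [b * max_len + (real_len[b] - 1) for b in range(batch_size)]
    # Equivalent of tf.gather(flat, index)
    return [flat[i] for i in index]


# Toy batch: 2 questions, max_len = 3, hidden = 2
outputs = [
    [[1, 1], [2, 2], [0, 0]],  # question 0 has 2 real tokens; step 3 is padding
    [[3, 3], [4, 4], [5, 5]],  # question 1 uses all 3 steps
]
print(last_valid_outputs(outputs, [2, 3]))  # [[2, 2], [5, 5]]
```

Here the gathered indices are 0 * 3 + 1 = 1 and 1 * 3 + 2 = 5. Because static_rnn keeps running over padding, this gather is what discards the padded steps. In TensorFlow 1.x, tf.nn.dynamic_rnn with its sequence_length argument is a commonly used alternative: it stops updating each sequence after its real length, so the returned final state already reflects the last real step.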