querysum_model.py source code

python

Project: querysum  Author: helmertz
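# Method of the model class; expects `import tensorflow as tf` (TensorFlow 1.x) at module level.
def _rnn_attention_decoder(self, decoder_cell, training_wheels):
    # Build the step function that drives tf.nn.raw_rnn.
    loop_fn = self._custom_rnn_loop_fn(decoder_cell.output_size, training_wheels=training_wheels)

    # raw_rnn returns (emit TensorArray, final cell state, final loop state).
    # The loop state carries three TensorArrays filled in by loop_fn.
    decoder_outputs, _, loop_state = tf.nn.raw_rnn(decoder_cell, loop_fn, swap_memory=True)
    context_vectors_array, attention_logits_array, pointer_probability_array = loop_state

    # Stack the per-step outputs (time-major) and switch to batch-major.
    decoder_outputs = decoder_outputs.stack()
    decoder_outputs = tf.transpose(decoder_outputs, [1, 0, 2])

    # For each loop-state array, read every entry except the last one,
    # then move the batch dimension to the front.
    attention_logits = attention_logits_array.gather(tf.range(0, attention_logits_array.size() - 1))
    attention_logits = tf.transpose(attention_logits, [1, 0, 2])

    context_vectors = context_vectors_array.gather(tf.range(0, context_vectors_array.size() - 1))
    context_vectors = tf.transpose(context_vectors, [1, 0, 2])

    # Pointer probabilities have one value per example and step, so the result is 2-D.
    pointer_probabilities = pointer_probability_array.gather(tf.range(0, pointer_probability_array.size() - 1))
    pointer_probabilities = tf.transpose(pointer_probabilities, [1, 0])

    return decoder_outputs, context_vectors, attention_logits, pointer_probabilities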
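
The shape handling in the method above can be illustrated without TensorFlow. The following is a minimal sketch in plain Python and is not part of querysum: nested lists stand in for tensors, the helper names `drop_last_step` and `time_major_to_batch_major` are hypothetical, and the sample values are made up. It assumes, as the `size() - 1` range suggests, that each loop-state array holds one more entry than there are decoding steps.

```python
def drop_last_step(steps):
    """Keep every per-step entry except the final one.

    Mirrors array.gather(tf.range(0, array.size() - 1)) on a plain list.
    """
    return steps[:len(steps) - 1]


def time_major_to_batch_major(steps):
    """Turn [time][batch] nesting into [batch][time].

    Mirrors tf.transpose(x, [1, 0]) for a 2-D, list-based stand-in.
    """
    return [list(per_example) for per_example in zip(*steps)]


# Hypothetical pointer probabilities: 3 decoding steps plus 1 trailing entry,
# for a batch of 2 examples (outer index is time, inner index is batch).
pointer_steps = [[0.1, 0.9], [0.2, 0.8], [0.3, 0.7], [0.0, 0.0]]

trimmed = drop_last_step(pointer_steps)
pointer_probabilities = time_major_to_batch_major(trimmed)
print(pointer_probabilities)  # [[0.1, 0.2, 0.3], [0.9, 0.8, 0.7]]
```

After the conversion, each row belongs to one example and each column to one decoding step, which is the layout the method returns for its 2-D output.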