querysum_model.py — source excerpt (Python)

Provenance: project "querysum", author helmertz; code listing captured from a
web page (page chrome such as view/like counters removed for clarity).
def _custom_rnn_loop_fn(self, cell_size, training_wheels: bool):
        """Build a loop function for a raw-RNN-style decoder with attention and
        pointer probabilities.

        Args:
            cell_size: forwarded to ``self._extract_argmax_and_embed`` when
                decoding greedily (exact meaning defined by that helper —
                presumably the decoder cell's output size; confirm there).
            training_wheels: if True, teacher-force the decoder by feeding the
                ground-truth previous token at each step; if False, feed the
                embedding of the model's own previous argmax prediction.

        Returns:
            A ``loop_fn(time, cell_output, cell_state, loop_state)`` callable
            following the ``tf.nn.raw_rnn`` loop-function contract: it returns
            ``(elements_finished, next_input, next_cell_state, emit_output,
            next_loop_state)``.
        """
        def loop_fn(time, cell_output, cell_state, loop_state):
            if cell_output is None:  # time == 0
                # First call: no output yet. Allocate per-step accumulators.
                # Sized to max reference length + 1 so there is a slot for
                # every decoding step including the initial <GO> step.
                context_vectors_array = tf.TensorArray(tf.float32, size=tf.shape(self.references_placeholder)[1] + 1)
                attention_logits_array = tf.TensorArray(tf.float32, size=tf.shape(self.references_placeholder)[1] + 1)
                pointer_probability_array = tf.TensorArray(tf.float32,
                                                           size=tf.shape(self.references_placeholder)[1] + 1)
                # Decoder starts from the encoder's final state, and the
                # "previous output" is the embedded <GO> token for every
                # sequence in the batch.
                next_cell_state = self.final_encoder_state
                go_id = self.summary_vocabulary.word_to_id('<GO>')
                last_output_embedding = tf.nn.embedding_lookup(self.embeddings, tf.tile([go_id], [self.batch_size]))
            else:
                # Subsequent calls: unpack accumulators threaded through
                # loop_state and carry the cell state forward unchanged.
                context_vectors_array, attention_logits_array, pointer_probability_array = loop_state
                next_cell_state = cell_state

                if training_wheels:
                    # Teacher forcing: the previous token comes from the
                    # ground-truth reference, not the model's prediction.
                    voc_indices = self.references_placeholder[:, time - 1]
                    pointer_indices = self.pointer_reference_placeholder[:, time - 1]
                    pointer_switch = tf.cast(self.pointer_switch_placeholder[:, time - 1], tf.bool)

                    # Gather, per batch element, the document token at the
                    # pointed-to position: indexer rows are (batch, position).
                    batch_range = tf.range(self.batch_size)
                    pointer_indexer = tf.stack([batch_range, pointer_indices], axis=1)
                    attention_vocabulary_indices = tf.gather_nd(self.documents_placeholder, pointer_indexer)

                    # Where the switch is set, the reference token was copied
                    # from the document (pointer); otherwise it is a regular
                    # vocabulary token.
                    mixed_indices = tf.where(pointer_switch, attention_vocabulary_indices, voc_indices)
                    last_output_embedding = tf.nn.embedding_lookup(self.embeddings, mixed_indices)
                else:
                    # Free-running decoding: embed the argmax of the previous
                    # cell output (helper defined elsewhere in the class).
                    last_output_embedding = self._extract_argmax_and_embed(cell_output, cell_size,
                                                                           tf.shape(self.documents_placeholder)[0])
            # Attention over the encoded document conditioned on the current
            # state and previous-token embedding, plus the copy-vs-generate
            # pointer probability for this step.
            context_vector, attention_logits = self._attention(next_cell_state, last_output_embedding)
            pointer_probabilities = self._pointer_probabilities(context_vector, next_cell_state, last_output_embedding)

            # Record this step's values; TensorArray.write returns a new
            # array object that must replace the old handle.
            context_vectors_array = context_vectors_array.write(time, context_vector)
            attention_logits_array = attention_logits_array.write(time, attention_logits)
            pointer_probability_array = pointer_probability_array.write(time, pointer_probabilities)

            # Next decoder input: previous-token embedding, attention context,
            # and self.query_last (presumably the query encoder's final
            # output — confirm against its definition), concatenated on the
            # feature axis.
            next_input = tf.concat([last_output_embedding, context_vector, self.query_last], axis=1)
            # A sequence is finished once we have decoded as many steps as
            # its reference summary length.
            elements_finished = (time >= self.reference_lengths_placeholder)

            emit_output = cell_output
            next_loop_state = (context_vectors_array, attention_logits_array, pointer_probability_array)
            return elements_finished, next_input, next_cell_state, emit_output, next_loop_state

        return loop_fn
(End of code excerpt. The remainder of the scraped page — comment list,
article navigation, and a WeChat QR-code follow prompt — carried no content
related to the source file.)