def _rnn_attention_decoder(self, decoder_cell, training_wheels):
    loop_fn = self._custom_rnn_loop_fn(decoder_cell.output_size,
                                       training_wheels=training_wheels)

    # raw_rnn threads a custom loop state through the decode loop; here the
    # loop_fn accumulates three TensorArrays (one entry per decode step) that
    # come back as the final loop state.
    decoder_outputs, _, (context_vectors_array,
                         attention_logits_array,
                         pointer_probability_array) = tf.nn.raw_rnn(decoder_cell,
                                                                    loop_fn,
                                                                    swap_memory=True)

    # Stack the emitted outputs and convert from time-major [T, B, D]
    # to batch-major [B, T, D].
    decoder_outputs = decoder_outputs.stack()
    decoder_outputs = tf.transpose(decoder_outputs, [1, 0, 2])

    # Each TensorArray carries one trailing entry written on the loop_fn's
    # final call, so gather everything but the last element, then transpose
    # to batch-major.
    attention_logits = attention_logits_array.gather(tf.range(0, attention_logits_array.size() - 1))
    attention_logits = tf.transpose(attention_logits, [1, 0, 2])

    context_vectors = context_vectors_array.gather(tf.range(0, context_vectors_array.size() - 1))
    context_vectors = tf.transpose(context_vectors, [1, 0, 2])

    pointer_probabilities = pointer_probability_array.gather(tf.range(0, pointer_probability_array.size() - 1))
    pointer_probabilities = tf.transpose(pointer_probabilities, [1, 0])

    return decoder_outputs, context_vectors, attention_logits, pointer_probabilities
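The `_custom_rnn_loop_fn` helper itself is not shown here, but the tuple of TensorArrays it hands back as `raw_rnn`'s final loop state follows the standard TF 1.x `loop_fn` contract. Below is a minimal, hypothetical sketch of such a loop_fn under assumed names: `make_loop_fn`, `attend` (a stand-in hook returning a context vector, attention logits, and a pointer probability), `start_embedded`, and `max_decode_steps` are all illustrative, and the original's `training_wheels` argument (presumably switching between teacher forcing and feeding back predictions) is omitted for brevity.

import tensorflow as tf

def make_loop_fn(cell, batch_size, max_decode_steps, initial_state,
                 start_embedded, attend):
    """Sketch of a raw_rnn loop_fn that accumulates attention artifacts.

    attend(query) -> (context_vector, attention_logits, pointer_probability)
    is a hypothetical hook standing in for the model's attention mechanism.
    """

    def loop_fn(time, cell_output, cell_state, loop_state):
        if cell_output is None:
            # First call (time == 0): declare the emit structure, the initial
            # cell state, and the three TensorArrays carried as loop state.
            next_cell_state = initial_state
            emit_output = tf.zeros([cell.output_size])
            contexts = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
            logits = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
            pointers = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
            query = tf.zeros([batch_size, cell.output_size])  # no output yet
            prev_output = start_embedded  # assumed width == cell.output_size
        else:
            next_cell_state = cell_state
            emit_output = cell_output
            contexts, logits, pointers = loop_state
            query = cell_output
            prev_output = cell_output

        # Attention runs on every call, including the final one (whose
        # next_input is never consumed), so each array ends with one trailing
        # entry that the caller drops via gather(size() - 1).
        context, attention_logits, pointer_prob = attend(query)
        next_loop_state = (contexts.write(time, context),
                           logits.write(time, attention_logits),
                           pointers.write(time, pointer_prob))

        # Input feeding: condition the next step on the attention context.
        next_input = tf.concat([prev_output, context], axis=-1)
        elements_finished = tf.fill([batch_size], time >= max_decode_steps)
        return (elements_finished, next_input, next_cell_state,
                emit_output, next_loop_state)

    return loop_fn

Under these assumptions, the write on `raw_rnn`'s final loop_fn call is exactly the trailing entry that `_rnn_attention_decoder` discards by gathering only the first `size() - 1` elements of each array.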