def _custom_rnn_loop_fn(self, cell_size, training_wheels):
def loop_fn(time, cell_output, cell_state, loop_state):
if cell_output is None: # time == 0
context_vectors_array = tf.TensorArray(tf.float32, size=tf.shape(self.references_placeholder)[1] + 1)
attention_logits_array = tf.TensorArray(tf.float32, size=tf.shape(self.references_placeholder)[1] + 1)
pointer_probability_array = tf.TensorArray(tf.float32,
size=tf.shape(self.references_placeholder)[1] + 1)
next_cell_state = self.final_encoder_state
go_id = self.summary_vocabulary.word_to_id('<GO>')
last_output_embedding = tf.nn.embedding_lookup(self.embeddings, tf.tile([go_id], [self.batch_size]))
else:
context_vectors_array, attention_logits_array, pointer_probability_array = loop_state
next_cell_state = cell_state
if training_wheels:
voc_indices = self.references_placeholder[:, time - 1]
pointer_indices = self.pointer_reference_placeholder[:, time - 1]
pointer_switch = tf.cast(self.pointer_switch_placeholder[:, time - 1], tf.bool)
batch_range = tf.range(self.batch_size)
pointer_indexer = tf.stack([batch_range, pointer_indices], axis=1)
attention_vocabulary_indices = tf.gather_nd(self.documents_placeholder, pointer_indexer)
mixed_indices = tf.where(pointer_switch, attention_vocabulary_indices, voc_indices)
last_output_embedding = tf.nn.embedding_lookup(self.embeddings, mixed_indices)
else:
last_output_embedding = self._extract_argmax_and_embed(cell_output, cell_size,
tf.shape(self.documents_placeholder)[0])
context_vector, attention_logits = self._attention(next_cell_state, last_output_embedding)
pointer_probabilities = self._pointer_probabilities(context_vector, next_cell_state, last_output_embedding)
context_vectors_array = context_vectors_array.write(time, context_vector)
attention_logits_array = attention_logits_array.write(time, attention_logits)
pointer_probability_array = pointer_probability_array.write(time, pointer_probabilities)
next_input = tf.concat([last_output_embedding, context_vector, self.query_last], axis=1)
elements_finished = (time >= self.reference_lengths_placeholder)
emit_output = cell_output
next_loop_state = (context_vectors_array, attention_logits_array, pointer_probability_array)
return elements_finished, next_input, next_cell_state, emit_output, next_loop_state
return loop_fn
评论列表
文章目录