# Assumed imports for this snippet: `import tensorflow as tf` and
# `from tensorflow.python.ops import tensor_array_ops, control_flow_ops`.
def decode(self, enc_outputs, enc_final_state):
    with tf.variable_scope(self.decoder.scope):
        def condition(time, all_outputs: tf.TensorArray, inputs, state):
            def check_outputs_ends():
                def has_end_word(t):
                    return tf.reduce_any(tf.equal(t, ANSWER_MAX))
                output_label = tf.argmax(all_outputs.stack(), axis=2)
                output_label = tf.Print(output_label, [output_label], "Output Labels: ")
                # The outputs are time-major, i.e. time is the first
                # dimension. To check whether every generated answer
                # already ends with "</s>", the labels are transposed
                # to batch-major first, because `map_fn` only maps over
                # the first dimension.
                batch_major_outputs = tf.transpose(output_label, (1, 0))
                all_outputs_ends = tf.reduce_all(
                    tf.map_fn(has_end_word, batch_major_outputs, dtype=tf.bool))
                return all_outputs_ends
            # If the TensorArray is empty, stack() raises an error, so a
            # tf.cond guards the check until at least one step has been
            # written.
            all_ends = tf.cond(tf.equal(all_outputs.size(), 0),
                               lambda: tf.constant(False, tf.bool),
                               check_outputs_ends)
            condition_result = tf.logical_and(tf.logical_not(all_ends),
                                              tf.less(time, ANSWER_MAX))
            return condition_result
        def body(time, all_outputs, inputs, state):
            dec_outputs, dec_state, output_logits, next_input = self.decoder.step(inputs, state)
            all_outputs = all_outputs.write(time, output_logits)
            return time + 1, all_outputs, next_input, dec_state
        output_ta = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                 size=0,
                                                 dynamic_size=True,
                                                 element_shape=(None, config.DEC_VOCAB),
                                                 clear_after_read=False)
        # With time-major input, the batch size is the second dimension.
        batch_size = tf.shape(enc_outputs)[1]
        zero_input = tf.ones(tf.expand_dims(batch_size, axis=0), dtype=tf.int32) * ANSWER_START
        res = control_flow_ops.while_loop(
            condition,
            body,
            loop_vars=[0, output_ta, self.decoder.zero_input(zero_input), enc_final_state],
        )
        final_outputs = res[1].stack()
        final_outputs = tf.Print(final_outputs, [final_outputs], "Final Output: ")
        final_state = res[3]
        return final_outputs, final_state
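
For reference, here is a minimal sketch of how this `decode` method could be consumed once the encoder has produced `enc_outputs` and `enc_final_state`. The names `model`, `feed` and `id2word` are placeholders for this example, not part of the original code. Since `final_outputs` is time-major with shape `(time, batch, DEC_VOCAB)`, it is reduced with `argmax` and transposed to batch-major before being mapped back to words.

# Hypothetical usage sketch: `model`, `feed` and `id2word` are placeholders.
final_outputs, final_state = model.decode(enc_outputs, enc_final_state)

# final_outputs is time-major (time, batch, DEC_VOCAB): take the most likely
# token per step, then flip to batch-major (batch, time) for readability.
output_ids = tf.transpose(tf.argmax(final_outputs, axis=2), (1, 0))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ids = sess.run(output_ids, feed_dict=feed)
    answers = [[id2word[i] for i in row] for row in ids]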