def _init_decoder(self):
    """Build the decoder part of the graph and store its outputs on ``self``.

    Each decoding step consumes the previous step's cell output as its
    input. ``dynamic_rnn`` cannot express that feedback, so the recurrence
    is driven by ``tf.nn.raw_rnn`` with a hand-written loop function that
    emulates ``dynamic_rnn`` with this behaviour.
    (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn.py)

    Reads from ``self``: ``encoder_inputs_length``, ``encoder_final_state``,
    ``decoder_cell``, ``batch_size``, ``EOS``, ``seq_width``, ``projection``.

    Sets on ``self``:
        decoder_inputs_length: ``encoder_inputs_length + 1`` (one extra
            step for the EOS token).
        decoder_outputs: the stacked per-step emissions with their first
            two axes swapped, passed through ``self.projection`` inside
            the ``'DecoderOutputProjection'`` variable scope.
    """
    # One extra decoding step accounts for the EOS token.
    self.decoder_inputs_length = self.encoder_inputs_length + 1

    def _step_shape():
        # Shape shared by every per-step input/output tensor of the decoder.
        return [self.batch_size, self.decoder_cell.output_size]

    def _first_step(time):
        # Seeds the recurrence: an EOS-valued input (0 + self.EOS) and the
        # encoder's final state as the decoder's starting state.
        done_mask = time >= self.decoder_inputs_length
        eos_input = tf.zeros(_step_shape(), dtype=tf.float32) + self.EOS
        # Tuple layout expected by raw_rnn:
        # (finished, next_input, next_cell_state, emit_output, next_loop_state).
        # The emit output is a dummy on the first call and no loop state
        # is carried between steps.
        return done_mask, eos_input, self.encoder_final_state, None, None

    def _step(time, cell_output, cell_state, loop_state):
        # No cell output yet means this is the initial call (time == 0).
        if cell_output is None:
            return _first_step(time)

        cell_output.set_shape(_step_shape())
        done_mask = time >= self.decoder_inputs_length
        all_done = tf.reduce_all(done_mask)

        def _zero_input():
            # Stand-in input (zeros) once every sequence has finished.
            return tf.zeros(_step_shape(), dtype=tf.float32)

        def _feed_back():
            # Otherwise the previous cell output becomes the next input.
            return cell_output

        fed_input = tf.cond(all_done, _zero_input, _feed_back)
        # The cell state passes through unchanged; the cell output is both
        # emitted and (via fed_input) fed back in.
        return done_mask, fed_input, cell_state, cell_output, None

    emitted_ta, _final_state, _final_loop_state = tf.nn.raw_rnn(self.decoder_cell, _step)

    # Stack the emitted steps into one tensor, then swap its first two axes.
    self.decoder_outputs = tf.transpose(emitted_ta.stack(), [1, 0, 2])

    with tf.variable_scope('DecoderOutputProjection') as projection_scope:
        self.decoder_outputs = self.projection(
            self.decoder_outputs, self.seq_width, projection_scope
        )
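# Comment list (page-navigation text left over from the source web page)
# Table of contents (page-navigation text left over from the source web page)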