def _call_helper(self):
    """Run the decoding loop and return its stacked outputs and final state.

    Builds the initial loop variables (an int32 step counter starting at 0,
    the decoder's initial input and state, a vector of ``False`` "finished"
    flags and an empty, dynamically sized float32 ``TensorArray``), runs
    ``tf.while_loop`` with ``self.cond`` and ``self.body``, then stacks the
    resulting ``TensorArray`` and swaps its first two axes.

    Returns:
        A pair ``(output, state)``: the stacked loop output with axes 0
        and 1 exchanged, and the value of the state loop variable after
        the loop has terminated.
    """
    step = tf.constant(0, dtype=tf.int32)
    first_input = self._decoder.init_input()
    first_state = self._decoder.init_state()

    # One "finished" flag per entry along dimension 0 of the initial input
    # (presumably the batch dimension -- confirm against the decoder).
    num_rows = utils.get_dimension(first_input, 0)
    not_finished = tf.tile([False], [num_rows])

    # Starts empty and grows as the loop body writes to it.
    collected = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)

    loop_result = tf.while_loop(
        cond=self.cond,
        body=self.body,
        loop_vars=[step, first_input, first_state, not_finished, collected],
        parallel_iterations=self._parallel_iterations,
        swap_memory=self._swap_memory)

    # The TensorArray is the last loop variable, the state is the third.
    stacked = loop_result[-1].stack()
    final_state = loop_result[2]

    # Swap the leading two axes of the rank-3 result; presumably this turns
    # a time-major tensor into a batch-major one -- confirm with self.body.
    return tf.transpose(stacked, [1, 0, 2]), final_state
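# 评论列表 (comment list) -- web-page navigation residue, not part of the code
# 文章目录 (table of contents) -- web-page navigation residue, not part of the code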