def _gru_encoder(cell, inputs, sequence_length, initial_state, dtype=None):
    """Run `cell` over `inputs` one time step at a time with `tf.while_loop`.

    Args:
        cell: Recurrent cell exposing `output_size`, `zero_state(batch, dtype)`
            and `cell(input_t, state) -> (output, new_state)`. Assumed to be
            GRUCell-like (a single state tensor).
        inputs: Rank-3 tensor; axis 0 is read as the batch size and axis 1 as
            the number of time steps (presumably `[batch, time, depth]`).
        sequence_length: Forwarded unchanged to `_copy_through` together with
            the current step. Presumably per-example valid lengths; confirm
            against `_copy_through`.
        initial_state: Starting state, or `None` to use
            `cell.zero_state(batch, dtype)`.
        dtype: Element type for the zero output, the zero state and both
            TensorArrays. Falls back to `inputs.dtype` when falsy.

    Returns:
        A pair `(all_output, final_state)`. `all_output` is the stacked
        per-step output transposed back to the axis order of `inputs`, with
        its last static dimension set to `cell.output_size`. `final_state` is
        the state carried out of the loop.
    """
    # Assume that the underlying cell is GRUCell-like
    output_size = cell.output_size
    if not dtype:
        dtype = inputs.dtype

    batch = tf.shape(inputs)[0]
    time_steps = tf.shape(inputs)[1]

    # Passed to _copy_through as the alternative to the fresh cell output.
    zero_output = tf.zeros([batch, output_size], dtype)

    if initial_state is None:
        initial_state = cell.zero_state(batch, dtype)

    input_ta = tf.TensorArray(dtype, time_steps,
                              tensor_array_name="input_array")
    output_ta = tf.TensorArray(dtype, time_steps,
                               tensor_array_name="output_array")

    # Swap the first two axes so that each TensorArray slot holds one step.
    steps_first = tf.transpose(inputs, [1, 0, 2])
    input_ta = input_ta.unstack(steps_first)

    def _keep_going(step, *_):
        return step < time_steps

    def _step(step, results_ta, prev_state):
        step_input = input_ta.read(step)
        step_output, next_state = cell(step_input, prev_state)
        # _copy_through picks between the fallback value and the new value
        # based on the step and sequence_length (helper defined elsewhere).
        step_output = _copy_through(step, sequence_length, zero_output,
                                    step_output)
        next_state = _copy_through(step, sequence_length, prev_state,
                                   next_state)
        results_ta = results_ta.write(step, step_output)
        return step + 1, results_ta, next_state

    start = tf.constant(0, dtype=tf.int32, name="time")
    _, filled_ta, final_state = tf.while_loop(
        _keep_going,
        _step,
        (start, output_ta, initial_state),
        parallel_iterations=32,
        swap_memory=True,
    )

    stacked = filled_ta.stack()
    stacked.set_shape([None, None, output_size])
    all_output = tf.transpose(stacked, [1, 0, 2])

    return all_output, final_state
评论列表
文章目录