def build_graph(self):
    """
    Build the computational graph that performs a step-by-step evaluation
    of the input data batches.

    The input is first converted to a TensorArray, then a ``tf.while_loop``
    runs ``self._loop_body`` for as long as the step counter is below
    ``self.sequence_length``. The per-step records collected by the loop
    are finally packed back into tensors along axis 1.

    Attributes set on ``self``:
        unpacked_input_data: result of ``utility.unpack_into_tensorarray``
            applied to ``self.input_data``.
        packed_output: the loop's ``outputs`` record, packed along axis 1.
        packed_memory_view: dict of the remaining per-step records, each
            packed along axis 1. Kept for debugging purposes only.

    Returns:
        None.
    """
    # NOTE(review): the positional ``1`` is presumably the axis to unpack
    # along (it matches ``axis=1`` used when packing below) -- confirm
    # against ``utility.unpack_into_tensorarray``.
    self.unpacked_input_data = utility.unpack_into_tensorarray(
        self.input_data, 1, self.sequence_length
    )

    def new_record():
        # All per-step records share the same dtype and size.
        return tf.TensorArray(tf.float32, self.sequence_length)

    # Created in the same order as before so graph op naming is unchanged.
    outputs = new_record()
    read_weightings = new_record()
    write_weightings = new_record()
    write_vectors = new_record()
    key_vectors = new_record()
    beta_vectors = new_record()
    shift_vectors = new_record()
    gamma_vectors = new_record()
    gates_vectors = new_record()
    memory_vectors = new_record()

    if self.controller.has_recurrent_nn:
        controller_state = self.controller.get_state()
    else:
        # The loop still carries a controller-state slot for non-recurrent
        # controllers; a pair of one-element zero tensors fills it.
        controller_state = (tf.zeros(1), tf.zeros(1))
    if not isinstance(controller_state, LSTMStateTuple):
        controller_state = LSTMStateTuple(controller_state[0], controller_state[1])

    memory_state = self.memory.init_memory()

    with tf.variable_scope("Sequence_Loop"):
        time = tf.constant(0, dtype=tf.int32)

        # tf.while_loop returns its results in the same structure as
        # ``loop_vars``, so they are unpacked by name right away instead of
        # being addressed later through positional indices. The first two
        # entries (final step counter and final memory state) are unused.
        (
            _, _, final_outputs,
            final_read_weightings, final_write_weightings,
            final_controller_state, final_write_vectors,
            final_key_vectors, final_beta_vectors, final_shift_vectors,
            final_gamma_vectors, final_gates_vectors, final_memory_vectors
        ) = tf.while_loop(
            cond=lambda time, *_: time < self.sequence_length,
            body=self._loop_body,
            loop_vars=(
                time, memory_state, outputs,
                read_weightings, write_weightings, controller_state, write_vectors,
                key_vectors, beta_vectors, shift_vectors, gamma_vectors,
                gates_vectors, memory_vectors
            ),
            parallel_iterations=32,
            swap_memory=True
        )

    dependencies = []
    if self.controller.has_recurrent_nn:
        dependencies.append(self.controller.update_state(final_controller_state))

    # Ops created inside this context only run after the controller state
    # update (when there is one) has been executed.
    with tf.control_dependencies(dependencies):
        self.packed_output = utility.pack_into_tensor(final_outputs, axis=1)
        # packed_memory_view and its content is just for debugging purposes.
        self.packed_memory_view = {
            'read_weightings': utility.pack_into_tensor(final_read_weightings, axis=1),
            'write_weightings': utility.pack_into_tensor(final_write_weightings, axis=1),
            'write_vectors': utility.pack_into_tensor(final_write_vectors, axis=1),
            'key_vectors': utility.pack_into_tensor(final_key_vectors, axis=1),
            'beta_vectors': utility.pack_into_tensor(final_beta_vectors, axis=1),
            'shift_vectors': utility.pack_into_tensor(final_shift_vectors, axis=1),
            'gamma_vectors': utility.pack_into_tensor(final_gamma_vectors, axis=1),
            'gates_vectors': utility.pack_into_tensor(final_gates_vectors, axis=1),
            'memory_vectors': utility.pack_into_tensor(final_memory_vectors, axis=1)
        }
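# Comment list  (appears to be page-navigation text from the site this snippet was copied from; not part of the program)
# Table of contents  (appears to be page-navigation text from the site this snippet was copied from; not part of the program)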