ntm.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:Neural-Turing-Machine 作者: camigord 项目源码 文件源码
def build_graph(self):
        """
        Build the step-by-step computational graph over the input batches.

        Unrolls the NTM for `self.sequence_length` steps via tf.while_loop,
        accumulating the controller outputs plus per-step addressing/memory
        diagnostics into TensorArrays, then packs them into dense tensors
        exposed as `self.packed_output` and `self.packed_memory_view`.
        """
        self.unpacked_input_data = utility.unpack_into_tensorarray(self.input_data, 1, self.sequence_length)

        # One float32 TensorArray per tracked signal, each sized to the sequence.
        def _make_ta():
            return tf.TensorArray(tf.float32, self.sequence_length)

        outputs = _make_ta()
        read_weightings = _make_ta()
        write_weightings = _make_ta()
        write_vectors = _make_ta()
        key_vectors = _make_ta()
        beta_vectors = _make_ta()
        shift_vectors = _make_ta()
        gamma_vectors = _make_ta()
        gates_vectors = _make_ta()
        memory_vectors = _make_ta()

        # Seed the controller state; feed-forward controllers get a dummy pair.
        if self.controller.has_recurrent_nn:
            controller_state = self.controller.get_state()
        else:
            controller_state = (tf.zeros(1), tf.zeros(1))
        if not isinstance(controller_state, LSTMStateTuple):
            controller_state = LSTMStateTuple(controller_state[0], controller_state[1])

        memory_state = self.memory.init_memory()

        with tf.variable_scope("Sequence_Loop") as scope:
            step = tf.constant(0, dtype=tf.int32)

            # NOTE: loop_vars order must match self._loop_body's signature;
            # the positional indices below depend on it.
            final_results = tf.while_loop(
                cond=lambda step, *_: step < self.sequence_length,
                body=self._loop_body,
                loop_vars=(
                    step, memory_state, outputs,
                    read_weightings, write_weightings, controller_state, write_vectors,
                    key_vectors, beta_vectors, shift_vectors, gamma_vectors,
                    gates_vectors, memory_vectors
                ),
                parallel_iterations=32,
                swap_memory=True
            )

        dependencies = []
        if self.controller.has_recurrent_nn:
            # final_results[5] is the controller state after the last step.
            dependencies.append(self.controller.update_state(final_results[5]))

        with tf.control_dependencies(dependencies):
            # final_results[2] is the outputs TensorArray.
            self.packed_output = utility.pack_into_tensor(final_results[2], axis=1)
            # packed_memory_view and its content is just for debugging purposes.
            # Map each debug name to its position in the while_loop result tuple.
            _view_slots = {
                'read_weightings': 3,
                'write_weightings': 4,
                'write_vectors': 6,
                'key_vectors': 7,
                'beta_vectors': 8,
                'shift_vectors': 9,
                'gamma_vectors': 10,
                'gates_vectors': 11,
                'memory_vectors': 12,
            }
            self.packed_memory_view = {
                name: utility.pack_into_tensor(final_results[slot], axis=1)
                for name, slot in _view_slots.items()
            }
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号