layer.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:2048-RL-DRQN 作者: Mostafa-Samir 项目源码 文件源码
def __call__(self, X):
        """
        Runs the LSTM over a list of per-time-step inputs.

        Performs the LSTM's forget, input and output operations
        according to: http://arxiv.org/pdf/1402.1128v1.pdf without peepholes

        Parameters:
        ----------
        X: list[Tensor]
            The input list to process by the LSTM, one entry per time step

        Returns:
        -------
        list[Tensor]
            The LSTM output of every time step, in the same order as X.
            Each returned tensor carries a control dependency on the ops
            that store the final state/output into self.prev_state and
            self.prev_output, so evaluating any of them persists the
            recurrent state.
        """
        num_steps = len(X)
        outputs = tf.TensorArray(tf.float32, num_steps)
        inputs = tf.TensorArray(tf.float32, num_steps)
        t = tf.constant(0, dtype=tf.int32)

        # Pack the Python list into a TensorArray so the loop body can index
        # it with the symbolic time counter.
        for i, step_input in enumerate(X):
            inputs = inputs.write(i, step_input)

        def step_op(time, prev_state, prev_output, inputs_list, outputs_list):
            """Computes one LSTM step and records its output at index `time`."""
            time_step = inputs_list.read(time)
            # All four gate pre-activations are produced by one fused affine
            # transform of the current input and the previous output.
            gates = (
                tf.matmul(time_step, self.input_weights)
                + tf.matmul(prev_output, self.output_weights)
                + self.bias
            )
            # The last axis (size 4) selects the gate.
            gates = tf.reshape(gates, [-1, self.num_hidden, 4])

            input_gate = tf.sigmoid(gates[:, :, 0])
            forget_gate = tf.sigmoid(gates[:, :, 1])
            candidate_state = tf.tanh(gates[:, :, 2])
            output_gate = tf.sigmoid(gates[:, :, 3])

            state = forget_gate * prev_state + input_gate * candidate_state
            output = output_gate * tf.tanh(state)
            new_outputs = outputs_list.write(time, output)

            return time + 1, state, output, inputs_list, new_outputs

        _, state, output, _, final_outputs = tf.while_loop(
            cond=lambda time, *_: time < num_steps,
            body=step_op,
            loop_vars=(t, self.prev_state, self.prev_output, inputs, outputs),
            parallel_iterations=32,
            swap_memory=True
        )

        # In graph mode `assign` only builds an op; unless something depends
        # on it, it never runs and the recurrent state is silently dropped.
        # Tie the state write-back to the returned outputs so it executes
        # whenever the outputs are evaluated.
        # NOTE(review): assumes prev_state/prev_output are variables whose
        # shape matches the final state/output -- confirm against __init__.
        persist_ops = [
            self.prev_state.assign(state),
            self.prev_output.assign(output),
        ]
        with tf.control_dependencies(persist_ops):
            # tf.identity creates the op inside the dependency scope so the
            # control edge is attached to what the caller receives.
            return [
                tf.identity(final_outputs.read(i)) for i in range(num_steps)
            ]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号