a3c.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:deep_rl_vizdoom 作者: mihahauke 项目源码 文件源码
def create_architecture(self):
        """Build the recurrent policy/value graph plus the frameskip heads.

        Creates the input placeholders (stored on ``self.vars``), a fully
        connected layer on top of ``self.get_input_layers()``, a recurrent
        layer unrolled with ``tf.nn.dynamic_rnn`` for a batch of size 1, and
        four output heads stored on ``self.ops``:

        * ``pi`` -- softmax over ``self.actions_num`` outputs,
        * ``v`` -- linear scalar output flattened to a vector,
        * ``frameskip_mu`` / ``frameskip_variance`` -- ReLU heads (mu is
          shifted by +1), from which ``frameskip_sigma`` and
          ``frameskip_policy`` are derived.
        """
        self.vars.sequence_length = tf.placeholder(tf.int64, [1], name="sequence_length")

        net_input = self.get_input_layers()
        scope_prefix = self._name_scope

        hidden = fully_connected(net_input,
                                 num_outputs=self.fc_units_num,
                                 scope=scope_prefix + "/fc1")

        # dynamic_rnn is fed batch-major input (time_major=False): a single
        # sequence whose time dimension is inferred from the fc output.
        hidden_sequence = tf.reshape(hidden, [1, -1, self.fc_units_num])

        cell = self.ru_class(self._recurrent_units_num)
        self.recurrent_cells = cell
        cell_state_size = cell.state_size
        initial_c = tf.placeholder(tf.float32, [1, cell_state_size.c], name="initial_lstm_state_c")
        initial_h = tf.placeholder(tf.float32, [1, cell_state_size.h], name="initial_lstm_state_h")
        self.vars.initial_network_state = LSTMStateTuple(initial_c, initial_h)

        rnn_outputs, final_state = tf.nn.dynamic_rnn(
            cell,
            hidden_sequence,
            initial_state=self.vars.initial_network_state,
            sequence_length=self.vars.sequence_length,
            time_major=False,
            scope=scope_prefix)
        self.ops.network_state = final_state

        # Collapse the (batch, time) dimensions so every head sees one row per step.
        rnn_flat = tf.reshape(rnn_outputs, [-1, self._recurrent_units_num])

        self.reset_state()

        # Policy head.
        self.ops.pi = fully_connected(rnn_flat,
                                      num_outputs=self.actions_num,
                                      scope=scope_prefix + "/fc_pi",
                                      activation_fn=tf.nn.softmax)

        # Value head: one linear unit per step, flattened to a vector.
        value_column = linear(rnn_flat,
                              num_outputs=1,
                              scope=scope_prefix + "/fc_value")
        self.ops.v = tf.reshape(value_column, [-1])

        # One frameskip distribution per action, or a single shared one.
        frameskip_units = self.actions_num if self.multi_frameskip else 1

        # Optionally keep the frameskip heads from back-propagating into the
        # recurrent layer.
        frameskip_features = tf.stop_gradient(rnn_flat) if self.fs_stop_gradient else rnn_flat

        mu_head = fully_connected(frameskip_features,
                                  num_outputs=frameskip_units,
                                  scope=scope_prefix + "/fc_frameskip_mu",
                                  activation_fn=tf.nn.relu,
                                  biases_initializer=tf.constant_initializer(self.fs_mu_bias))
        # ReLU output is >= 0, so mu is always at least 1.
        frameskip_mu = 1 + mu_head

        frameskip_variance = fully_connected(frameskip_features,
                                             num_outputs=frameskip_units,
                                             scope=scope_prefix + "/fc_frameskip_variance",
                                             activation_fn=tf.nn.relu,
                                             biases_initializer=tf.constant_initializer(
                                                 self.fs_sigma_bias))

        if not self.multi_frameskip:
            # Single-output heads are flattened to vectors.
            frameskip_mu = tf.reshape(frameskip_mu, (-1,))
            frameskip_variance = tf.reshape(frameskip_variance, (-1,))

        self.ops.frameskip_mu = frameskip_mu
        self.ops.frameskip_variance = frameskip_variance

        # NOTE(review): the variance head ends in ReLU, so it can be exactly 0
        # and sigma becomes 0 as well -- confirm downstream code tolerates that.
        frameskip_sigma = tf.sqrt(frameskip_variance, name="frameskip_sigma")
        self.ops.frameskip_sigma = frameskip_sigma
        self.ops.frameskip_policy = [frameskip_mu, frameskip_sigma]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号