reinforce.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目：DeepRL　作者：arnomoonens（项目源码 | 文件源码）
def build_network(self):
        """Build the recurrent (GRU) policy graph and its action-sampling op.

        All fed observations are run through one GRU as a single sequence
        (batch size 1). Every GRU output is mapped by a linear layer to one
        score per discrete action, normalised with a softmax. One action is
        sampled per time step.

        Attributes set on ``self``:
            rnn_state: initialised to ``None``.
            states: float32 observation placeholder of shape
                ``[None] + list(self.env.observation_space.shape)``.
            a_n: float32 placeholder for the actions taken.
            adv_n: float32 placeholder for the advantages.
            rnn_state_in: GRU zero state for a batch of size 1.
            rnn_state_out: GRU final state returned by ``tf.nn.dynamic_rnn``.
            logits: unnormalised action scores, one row per fed observation.
            probs: softmax of ``logits`` (action probabilities).
            action: sampled action indices, squeezed. This is a scalar when
                exactly one observation is fed and a vector otherwise.
        """
        self.rnn_state = None
        self.states = tf.placeholder(tf.float32, [None] + list(self.env.observation_space.shape), name="states")  # Observation
        self.a_n = tf.placeholder(tf.float32, name="a_n")  # Discrete action (fed as float32)
        self.adv_n = tf.placeholder(tf.float32, name="adv_n")  # Advantage

        # The fed observations form ONE sequence (batch of size 1). Its length
        # is the number of rows fed, given as a length-1 vector as
        # dynamic_rnn's sequence_length expects.
        n_states = tf.shape(self.states)[:1]

        # flatten() is defined elsewhere in this file; presumably it turns each
        # observation into a flat vector (TODO confirm). expand_dims prepends
        # the batch axis of size 1.
        states = tf.expand_dims(flatten(self.states), [0])

        enc_cell = tf.contrib.rnn.GRUCell(self.config["n_hidden_units"])
        self.rnn_state_in = enc_cell.zero_state(1, tf.float32)
        rnn_outputs, self.rnn_state_out = tf.nn.dynamic_rnn(cell=enc_cell,
                                                            inputs=states,
                                                            sequence_length=n_states,
                                                            initial_state=self.rnn_state_in,
                                                            dtype=tf.float32)
        # Linear output layer (same variable scope and initialisers as before).
        # The softmax is applied separately so the raw logits stay available
        # for sampling.
        self.logits = tf.contrib.layers.fully_connected(
            inputs=rnn_outputs[0],  # drop the batch axis -> one row per time step
            num_outputs=self.env_runner.nA,
            activation_fn=None,
            weights_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.02),
            biases_initializer=tf.zeros_initializer())
        self.probs = tf.nn.softmax(self.logits)
        # tf.multinomial expects unnormalised log-probabilities. log(softmax(x))
        # differs from x only by a per-row constant, so sampling directly from
        # the logits gives the same distribution. It also avoids the log(0) = -inf
        # that tf.log(self.probs) produced whenever a probability underflowed.
        self.action = tf.squeeze(tf.multinomial(self.logits, 1), name="action")
评论列表
文章目录


问题


面经


文章

微信公众号

扫码关注公众号