approximators.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:dqn_vizdoom_theano 作者: mihahauke 项目源码 文件源码
def _compile(self, ddqn):
        """Build the Q-learning loss and compile the Theano training/evaluation functions.

        On return two compiled functions are stored on the instance:

        * ``self._learn(s0..., s1..., a, r, nonterminal)`` applies one
          ``deepmind_rmsprop`` update to the trainable params of
          ``self.network`` and returns the summed loss.
        * ``self._evaluate(s0...)`` returns ``q``, the deterministic output of
          ``self.network``.

        ``s0...`` / ``s1...`` stand for the image input followed, when
        ``self.misc_state_included`` is set, by the matching misc-state input.

        Args:
            ddqn: when truthy, build a Double-DQN target. The greedy next
                action is taken from ``self.network`` evaluated on
                ``self.alternate_input_mappings`` (presumably the S1 inputs;
                confirm against where that mapping is built), and that action
                is valued by ``self.frozen_network``. Otherwise the target uses
                the row-wise max of ``self.frozen_network``'s output.
        """
        a = self.inputs["A"]
        r = self.inputs["R"]
        nonterminal = self.inputs["Nonterminal"]

        q = ls.get_output(self.network, deterministic=True)

        if ddqn:
            # Double DQN: select the next action with the online network...
            q2 = ls.get_output(self.network, deterministic=True, inputs=self.alternate_input_mappings)
            q2_action_ref = tensor.argmax(q2, axis=1)
            # ...but evaluate that action with the frozen network.
            q2_frozen = ls.get_output(self.frozen_network, deterministic=True)
            q2_max = q2_frozen[tensor.arange(q2_action_ref.shape[0]), q2_action_ref]
        else:
            q2_max = tensor.max(ls.get_output(self.frozen_network, deterministic=True), axis=1)

        # One-step bootstrapped target. ``nonterminal`` presumably is 0 for
        # terminal transitions so only the reward counts there (confirm dtype/values).
        target_q = r + self.gamma * nonterminal * q2_max
        # Per-row output of ``q`` at the action index given in ``a``.
        predicted_q = q[tensor.arange(q.shape[0]), a]

        loss = self.build_loss_expression(predicted_q, target_q).sum()
        params = ls.get_all_params(self.network, trainable=True)
        updates = deepmind_rmsprop(loss, params, self.learning_rate)

        # None -> Theano's configured default mode (theano.config.mode).
        # TODO: measure whether an explicit "FAST_RUN" speeds anything up.
        mode = None

        # Build the symbolic input lists once. Each misc-state input directly
        # follows its image input, so the argument order is unchanged:
        #   learn: S0, [S0_misc], S1, [S1_misc], A, R, Nonterminal
        #   eval:  S0, [S0_misc]
        s0_inputs = [self.inputs["S0"]]
        s1_inputs = [self.inputs["S1"]]
        if self.misc_state_included:
            s0_inputs.append(self.inputs["S0_misc"])
            s1_inputs.append(self.inputs["S1_misc"])
        learn_inputs = s0_inputs + s1_inputs + [a, r, nonterminal]

        # Single-argument print(...) prints the same text on Python 2 and 3.
        print("Compiling the training function...")
        self._learn = theano.function(learn_inputs, loss, updates=updates, mode=mode,
                                      name="learn_fn")
        print("Compiling the evaluation function...")
        self._evaluate = theano.function(s0_inputs, q, mode=mode, name="eval_fn")
        print("Network compiled.")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号