dqn.py source code

python

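Project: pytorch.rl.learning  Author: moskomule
def _train_nn(self):
    # Neural-network part: one DQN update on a minibatch from the replay memory
    self.optimizer.zero_grad()
    batch_state_before, batch_action, batch_reward, batch_state_after, batch_done = self.get_batch()
    # Bootstrapped TD target built from the rewards, next states and done flags
    target = self.agent.estimate_value(batch_reward, batch_state_after, batch_done)
    # Q(s, a) predicted by the network for the actions that were taken
    q_value = self.agent.q_value(batch_state_before, batch_action)
    loss = self.agent.net.loss(q_value, target)
    # The loss is computed every step, but weights change only every gradient_update_freq steps
    if self._step % self.gradient_update_freq == 0:
        loss.backward()
        self.optimizer.step()

    if self._step % self.log_freq_by_step == 0:
        self._writer.add_scalar("epsilon", self.agent.epsilon, self._step)
        # Mean gap between predicted Q-values and their targets
        self._writer.add_scalar("q_net-target", (q_value.detach() - target.detach()).mean().item(), self._step)
        # .item() converts the 0-dim loss tensor to a Python float
        self._writer.add_scalar("loss", loss.item(), self._step)

    return loss.item()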
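The snippet does not show `agent.estimate_value`, `agent.q_value` or `agent.net.loss`, so the arithmetic they perform has to be inferred. Below is a minimal pure-Python sketch, assuming the standard one-step DQN target r + gamma * max_a' Q(s', a') with no bootstrapping on terminal transitions, and assuming a mean-squared-error loss (the repository may instead use a Huber loss or a separate target network). The helpers `td_targets`, `q_of_taken_actions`, `mse_loss` and `should_update` are hypothetical names introduced only for illustration.

```python
from typing import List, Sequence


def td_targets(rewards: Sequence[float],
               next_q_values: Sequence[Sequence[float]],
               dones: Sequence[bool],
               gamma: float = 0.99) -> List[float]:
    """Assumed one-step Q-learning target: r + gamma * max_a' Q(s', a').

    Terminal transitions (done=True) do not bootstrap from the next state.
    """
    return [r + (0.0 if done else gamma * max(q_next))
            for r, q_next, done in zip(rewards, next_q_values, dones)]


def q_of_taken_actions(q_values: Sequence[Sequence[float]],
                       actions: Sequence[int]) -> List[float]:
    """Pick Q(s, a) for the action actually taken in each transition."""
    return [q[a] for q, a in zip(q_values, actions)]


def mse_loss(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared TD error, a hypothetical stand-in for agent.net.loss."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def should_update(step: int, every: int) -> bool:
    """Same gate as `self._step % self.gradient_update_freq == 0`."""
    return step % every == 0


if __name__ == "__main__":
    rewards = [1.0, 0.0]
    next_q = [[0.5, 2.0], [3.0, 1.0]]
    dones = [False, True]
    q_now = [[1.5, 0.0], [0.0, 1.0]]
    actions = [0, 1]

    target = td_targets(rewards, next_q, dones, gamma=0.5)  # [2.0, 0.0]
    q_value = q_of_taken_actions(q_now, actions)           # [1.5, 1.0]
    print(mse_loss(q_value, target))                       # 0.625
    # The "q_net-target" diagnostic: mean of (Q - target)
    print(sum(q - t for q, t in zip(q_value, target)) / len(target))  # 0.25
```

A positive "q_net-target" value means the network's predictions sit above their bootstrapped targets on average, which is one way to watch for Q-value overestimation during training.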