acer_single_process.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:pytorch-rl 作者: jingweiz 项目源码 文件源码
def _1st_order_trpo(self, detached_policy_loss_vb, detached_policy_vb, detached_avg_policy_vb, detached_splitted_policy_vb=None):
        """Return the trust-region-adjusted gradient ``z*`` of the policy loss.

        With ``k`` the gradient of the KL term and ``g`` the gradient of the
        policy loss (both taken w.r.t. the network's policy output), and
        ``delta = self.master.clip_1st_order_trpo``, the result is::

            z* = g - max(0, (k . g - delta) / (k . k)) * k

        where the dot products are sums along dim 1 (kept as a size-1 dim).

        The KL term is ``categorical_kl_div(detached_policy_vb,
        detached_avg_policy_vb)``. The original author's note describes it as
        DKL[pi(.|phi_theta_a) || pi(.|phi_theta)]; that helper is not visible
        here, so the exact direction should be confirmed against it. The
        original notes also say the built-in ``F.kl_div`` was not used because
        it does not work on a batch.

        Args:
            detached_policy_loss_vb: policy loss to differentiate.
            detached_policy_vb: policy output the loss and KL depend on.
            detached_avg_policy_vb: average-policy output used in the KL term.
            detached_splitted_policy_vb: ``None`` selects the on-policy path.
                Otherwise the per-row pieces of the policy output; gradients
                are then taken w.r.t. these pieces.
                (presumably one piece per batch row, matching ``split(1, 0)``
                below - verify against the caller)

        Returns:
            ``z*`` as defined above.
        """
        def _grads_of(outputs, inputs):
            # One-shot differentiation: the graph is not retained, and since
            # create_graph is not requested the returned gradients do not feed
            # back into the model - only detached policy variables are used.
            return grad(outputs=outputs, inputs=inputs, retain_graph=False, only_inputs=True)

        kl_vb = categorical_kl_div(detached_policy_vb, detached_avg_policy_vb)

        if detached_splitted_policy_vb is not None:
            # Off-policy: differentiate each size-1 slice along dim 0 w.r.t.
            # the pre-split policy pieces, then stitch the rows back together.
            # Per the original note, the policy must be split *before* it is
            # detached; splitting afterwards would leave the pieces with no
            # connection to the graph these gradients are traced through.
            kl_grad_vb = torch.cat(
                _grads_of(kl_vb.split(1, 0), detached_splitted_policy_vb), 0)
            loss_grad_vb = torch.cat(
                _grads_of(detached_policy_loss_vb.split(1, 0), detached_splitted_policy_vb), 0)
        else:
            # On-policy: differentiate directly w.r.t. the policy output.
            kl_grad_vb = _grads_of(kl_vb, detached_policy_vb)[0]
            loss_grad_vb = _grads_of(detached_policy_loss_vb, detached_policy_vb)[0]

        # Row-wise dot products k.g and k.k.
        k_dot_g_vb = (kl_grad_vb * loss_grad_vb).sum(dim=1, keepdim=True)
        k_dot_k_vb = (kl_grad_vb * kl_grad_vb).sum(dim=1, keepdim=True)

        # Only subtract along k when k.g exceeds the threshold (scale >= 0).
        threshold = self.master.clip_1st_order_trpo
        scale_vb = torch.clamp((k_dot_g_vb - threshold) / k_dot_k_vb, min=0)

        return loss_grad_vb - scale_vb * kl_grad_vb
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号