pcl.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:chainerrl 作者: chainer 项目源码 文件源码
def compute_loss(self, t_start, t_stop, rewards, values,
                 next_values, log_probs):
    """Compute the combined policy/value PCL loss for steps [t_start, t_stop).

    For every step ``t`` in the span, a rollout window of length
    ``min(t_stop - t, self.rollout_len)`` is used to build two path
    consistency errors. Each error combines the following terms:
    ``-V(s_t)``, the discounted bootstrap value
    ``gamma**d * next_values[t + d - 1]``, the discounted reward sum, and
    ``-tau`` times the discounted log-likelihood sum.

    * The policy error detaches the value terms (``.data``), so gradient
      flows only through ``log_probs``.
    * The value error detaches the log-likelihood sum, so gradient flows
      through ``values`` (and through ``next_values`` only when
      ``self.backprop_future_values`` is true).

    Args:
        t_start: First time step (inclusive) of the span.
        t_stop: Last time step (exclusive) of the span.
        rewards: Per-step rewards, indexed by absolute step ``t``.
        values: Per-step state values (chainer Variables), indexed by ``t``.
        next_values: Per-step next-state values, indexed by ``t``.
        log_probs: Per-step action log-likelihoods, indexed by ``t``.

    All four containers must have exactly ``t_stop - t_start`` entries yet
    are indexed with absolute steps in ``[t_start, t_stop)``. Presumably
    they are mappings keyed by time step (or ``t_start`` is 0). This is
    NOTE(review): to be confirmed against the callers.

    Returns:
        The sum of the policy loss and the value loss. The value loss is
        reshaped to the policy loss's shape before they are added.
    """
    seq_len = t_stop - t_start
    for container in (rewards, values, next_values, log_probs):
        assert len(container) == seq_len

    gamma = self.gamma
    tau = self.tau
    policy_terms = []
    value_terms = []
    for t in range(t_start, t_stop):
        # Truncate the window at the end of the span.
        horizon = min(t_stop - t, self.rollout_len)
        discounts = [gamma ** k for k in range(horizon)]

        # Discounted sum of immediate rewards over the window.
        reward_sum = sum(w * rewards[t + k] for k, w in enumerate(discounts))

        # Discounted sum of log-likelihoods over the same window. An
        # axis is appended so its shape lines up with the value terms.
        log_prob_sum = chainerrl.functions.weighted_sum_arrays(
            xs=[log_probs[t + k] for k in range(horizon)],
            weights=discounts)
        log_prob_sum = F.expand_dims(log_prob_sum, -1)

        bootstrap_v = next_values[t + horizon - 1]
        if not self.backprop_future_values:
            # Re-wrap the raw array to cut the graph at the bootstrap value.
            bootstrap_v = chainer.Variable(bootstrap_v.data)
        tail = gamma ** horizon

        # Policy consistency error: value terms are detached via .data.
        c_pi = (-values[t].data + tail * bootstrap_v.data + reward_sum
                - tau * log_prob_sum)
        # Value consistency error: the log-likelihood sum is detached.
        c_v = (-values[t] + tail * bootstrap_v + reward_sum
               - tau * log_prob_sum.data)

        policy_terms.append(c_pi ** 2)
        value_terms.append(c_v ** 2)

    pi_loss = chainerrl.functions.sum_arrays(policy_terms) / 2
    v_loss = chainerrl.functions.sum_arrays(value_terms) / 2

    # Divide by tau so the policy loss scale does not depend on tau.
    pi_loss = pi_loss / tau

    pi_loss = pi_loss * self.pi_loss_coef
    v_loss = v_loss * self.v_loss_coef

    if self.normalize_loss_by_steps:
        pi_loss = pi_loss / seq_len
        v_loss = v_loss / seq_len

    # Only the worker with process_idx 0 emits the debug line.
    if self.process_idx == 0:
        self.logger.debug('pi_loss:%s v_loss:%s',
                          pi_loss.data, v_loss.data)

    return pi_loss + F.reshape(v_loss, pi_loss.data.shape)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号