reinforce.py source code


Project: chainerrl, author: chainer, language: Python

The accumulate_grad method below comes from chainerrl's REINFORCE agent. It turns the reward, log-probability, and entropy sequences recorded while the agent acted into per-step policy-gradient losses (minimizing -R * log_prob performs gradient ascent on expected return), sums them, normalizes by the batch size, and calls backward() to accumulate gradients into the model.
def accumulate_grad(self):
    if self.n_backward == 0:
        self.model.zerograds()
    # Compute losses
    losses = []
    for r_seq, log_prob_seq, ent_seq in zip(self.reward_sequences,
                                            self.log_prob_sequences,
                                            self.entropy_sequences):
        assert len(r_seq) - 1 == len(log_prob_seq) == len(ent_seq)
        # Convert rewards into returns (= sums of future rewards).
        # r_seq[0] is the reward received before the first action, so it is
        # excluded; R_seq[t] is the undiscounted sum of the rewards that
        # follow action t.
        R_seq = np.cumsum(list(reversed(r_seq[1:])))[::-1]
        for R, log_prob, entropy in zip(R_seq, log_prob_seq, ent_seq):
            # REINFORCE loss: return-weighted negative log-likelihood,
            # minus an entropy bonus scaled by self.beta.
            loss = -R * log_prob - self.beta * entropy
            losses.append(loss)
    total_loss = chainerrl.functions.sum_arrays(losses)
    # When self.batchsize is future.types.newint.newint, dividing a
    # Variable with it will raise an error, so it is manually converted to
    # float here.
    total_loss /= float(self.batchsize)
    total_loss.backward()
    self.reward_sequences = [[]]
    self.log_prob_sequences = [[]]
    self.entropy_sequences = [[]]
    self.n_backward += 1
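
To see what the cumulative-sum line computes, here is a minimal standalone sketch. The reward values are made up for illustration and do not come from the source:

import numpy as np

# Hypothetical episode rewards. r_seq[0] is the reward received before the
# first action, so accumulate_grad drops it.
r_seq = [0.0, 1.0, 0.0, 2.0]

# Same expression as in accumulate_grad: reverse the post-action rewards,
# take a cumulative sum, then reverse back.
R_seq = np.cumsum(list(reversed(r_seq[1:])))[::-1]

print(R_seq)  # [3. 2. 2.], i.e. R_seq[t] sums the rewards that follow action t

Note also that zerograds() runs only on the first call (n_backward == 0), so gradients from repeated calls accumulate in the model, presumably until the agent's update step applies them and resets the counter.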