dqn.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:DQN 作者: Ivehui 项目源码 文件源码
def train(self, tran, selected):
        self.targetNet.blobs['frames'].data[...] \
            = tran.frames[selected + 1].copy()
        netOut = self.targetNet.forward()

        target = np.tile(tran.reward[selected]
                         + pms.discount
                         * tran.n_last[selected]
                         * np.resize(netOut['value_q'].max(1),
                                     (pms.batchSize, 1)),
                         (pms.actionSize,)
                         ) * tran.action[selected]

        self.solver.net.blobs['target'].data[...] = target
        self.solver.net.blobs['frames'].data[...] = tran.frames[selected].copy()
        self.solver.net.blobs['filter'].data[...] = tran.action[selected].copy()
        self.solver.step(1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号