q-network.py source file

python

Project: CartPole-v0 Author: hmtai6
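import numpy as np

# nbReplay (batch size) and td_discount_rate (discount factor) are assumed to be
# module-level globals, as is `agent`; none of them is defined in this excerpt.
def reply(self):
    # Sample a minibatch of (s, a, r, s_) transitions; s_ is None when terminal.
    batch = self.memory.sample(nbReplay)

    # Zero placeholder so terminal next-states keep the shape of a real state.
    no_state = np.zeros(self.stateCnt)

    states = np.array([ o[0] for o in batch ])
    states_ = np.array([ (no_state if o[3] is None else o[3]) for o in batch ])

    # Q-values for the current and next states, one batched call each.
    p = agent.brain.predict(states)
    p_ = agent.brain.predict(states_)

    x = np.zeros((nbReplay, self.stateCnt))
    y = np.zeros((nbReplay, self.actionCnt))

    for i in range(nbReplay):
        o = batch[i]
        s = o[0]; a = o[1]; r = o[2]; s_ = o[3]

        # Only the taken action's target changes; the others keep the prediction.
        t = p[i]
        if s_ is None:
            t[a] = r
        else:
            t[a] = r + td_discount_rate * np.amax(p_[i])

        x[i] = s
        y[i] = t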
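
The target construction in the method above can be tried in isolation. The following is a minimal sketch, not the author's code: `build_targets`, `stub_predict`, and the `gamma` value are hypothetical names and values introduced for illustration, and the stub stands in for whatever `agent.brain.predict` does in the project.

```python
import numpy as np


def build_targets(batch, predict, state_cnt, gamma):
    """Build (x, y) training arrays from (s, a, r, s_) transitions.

    `predict` is any callable mapping an (n, state_cnt) array to an
    (n, action_cnt) array of Q-values (hypothetical stand-in for the model).
    s_ is None for terminal transitions.
    """
    no_state = np.zeros(state_cnt)  # placeholder for terminal next-states
    states = np.array([s for s, _, _, _ in batch])
    states_ = np.array([no_state if s_ is None else s_ for _, _, _, s_ in batch])

    q = np.array(predict(states), dtype=float)  # copy, so targets can be edited
    q_ = predict(states_)

    for i, (_, a, r, s_) in enumerate(batch):
        # Terminal transitions have no future value to bootstrap from.
        q[i, a] = r if s_ is None else r + gamma * np.amax(q_[i])
    return states, q


def stub_predict(states):
    # Hypothetical "network": Q(s, 0) = sum(s), Q(s, 1) = 2 * sum(s).
    total = states.sum(axis=1, keepdims=True)
    return np.hstack([total, 2.0 * total])


batch = [
    (np.array([1.0, 0.0]), 0, 1.0, np.array([0.0, 1.0])),  # non-terminal
    (np.array([0.0, 2.0]), 1, 1.0, None),                  # terminal
]
x, y = build_targets(batch, stub_predict, state_cnt=2, gamma=0.5)
print(y)
```

For the non-terminal transition, the target for the taken action becomes the reward plus the discounted maximum Q-value of the next state; for the terminal one it is the reward alone. The other action in each row keeps its predicted value, so it contributes no error when the rows are used as regression targets.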