def train(self, tran, selected):
    """Run one solver step on the transitions indexed by ``selected``.

    The frames at ``selected + 1`` are pushed through ``self.targetNet``;
    the row-wise maximum of its ``'value_q'`` output is scaled by
    ``pms.discount`` and ``tran.n_last[selected]``, added to
    ``tran.reward[selected]``, repeated ``pms.actionSize`` times and
    multiplied element-wise by ``tran.action[selected]``. The result is
    written to the solver net's ``'target'`` blob, together with the frames
    and actions at ``selected``, before ``self.solver.step(1)`` is called.

    Args:
        tran: Transition store exposing ``frames``, ``reward``, ``n_last``
            and ``action``, each indexable by ``selected``.
        selected: Indices of the sampled transitions. It must support
            ``selected + 1`` (e.g. a NumPy integer array).

    NOTE(review): ``n_last`` presumably flags non-terminal transitions and
    ``action`` presumably holds one-hot rows -- confirm against the
    transition container.
    """
    # Evaluate the target network on the frames at selected + 1.
    self.targetNet.blobs['frames'].data[...] = tran.frames[selected + 1].copy()
    target_out = self.targetNet.forward()

    rewards = tran.reward[selected]
    n_last_sel = tran.n_last[selected]

    # Maximum of 'value_q' along axis 1, laid out as a (batchSize, 1) column.
    best_q = np.resize(target_out['value_q'].max(1), (pms.batchSize, 1))
    per_sample = rewards + pms.discount * n_last_sel * best_q

    # Repeat the per-sample value actionSize times along the last axis, then
    # multiply element-wise by the selected action rows.
    target = np.tile(per_sample, (pms.actionSize,)) * tran.action[selected]

    solver_blobs = self.solver.net.blobs
    solver_blobs['target'].data[...] = target
    solver_blobs['frames'].data[...] = tran.frames[selected].copy()
    solver_blobs['filter'].data[...] = tran.action[selected].copy()
    self.solver.step(1)
评论列表
文章目录