qlearning.py source file

python
Project: malmo-challenge  Author: Microsoft
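import numpy as np
from chainer import cuda
import chainer.functions as F


# Method of the Q-learning model class (excerpt).
def train(self, x, y, actions=None):
    # actions holds the index of the action taken in each sample
    if actions is None:
        raise ValueError("train() requires the action taken in each sample")
    actions = actions.astype(np.int32)
    batch_size = len(actions)

    # Move the batch to the GPU when one is configured
    if self._gpu_device:
        x = cuda.to_gpu(x, self._gpu_device)
        y = cuda.to_gpu(y, self._gpu_device)
        actions = cuda.to_gpu(actions, self._gpu_device)

    # Q-values for every action, then keep only the one taken in each sample
    q = self._model(x)
    q_subset = F.reshape(F.select_item(q, actions), (batch_size, 1))
    y = y.reshape(batch_size, 1)

    # Huber loss (delta = 1.0), summed over the batch
    loss = F.sum(F.huber_loss(q_subset, y, 1.0))

    self._model.cleargrads()
    loss.backward()
    self._optimizer.update()

    # np.asscalar is deprecated in NumPy; float() is equivalent for a 0-d array
    self._loss_val = float(cuda.to_cpu(loss.data))

    # Keeps track of the number of train() calls
    self._steps += 1
    if self._steps % self._target_update_interval == 0:
        # Copy the online network's weights into the target network
        self._target.copyparams(self._model)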
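
The loss in train() can be hard to picture from the Chainer calls alone. The following is a minimal pure-Python sketch of the same arithmetic: it picks the Q-value of the action taken in each row and sums a Huber penalty against the target. The names huber and selected_q_loss are hypothetical helpers for illustration only, not part of malmo-challenge or Chainer, and the sketch assumes the standard Huber definition (quadratic up to delta, linear beyond it).

```python
def huber(residual, delta=1.0):
    """Huber penalty for one residual: quadratic near zero, linear beyond delta."""
    a = abs(residual)
    if a <= delta:
        return 0.5 * a * a
    return delta * (a - 0.5 * delta)


def selected_q_loss(q, actions, targets, delta=1.0):
    """Sum of Huber penalties between Q(s, a_taken) and the target of each row."""
    total = 0.0
    for q_row, action, target in zip(q, actions, targets):
        # Keep only the Q-value of the action that was actually taken
        total += huber(q_row[action] - target, delta)
    return total


# Two samples, three possible actions each (made-up illustrative numbers)
q = [[1.0, 2.0, 0.5],
     [0.0, -1.0, 3.0]]
actions = [1, 2]        # action taken in each sample
targets = [2.5, 0.0]    # target value y for each sample

loss = selected_q_loss(q, actions, targets)
print(loss)  # 0.125 (quadratic branch) + 2.5 (linear branch) = 2.625
```

The second sample shows why the Huber loss is a common choice here: a residual of 3.0 contributes 2.5 instead of the 4.5 that a squared error would give, so large temporal-difference errors produce bounded gradients.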