model.py source code

python

Project: self-driving-cars  Author: musyoku
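import numpy as np
from chainer import Variable
# `config` is the project's own settings object; its import is not part of this excerpt.

# Method of the project's model class (class definition not shown).
def eps_greedy(self, state_batch, exploration_rate):
    # Normalize the input to shape (batch_size, state_dim).
    if state_batch.ndim == 1:
        state_batch = state_batch.reshape(1, -1)
    elif state_batch.ndim == 3:
        # 34 is hard-coded in the original; presumably the per-step feature size.
        state_batch = state_batch.reshape(-1, 34 * config.rl_history_length)
    # A single uniform draw decides explore vs. exploit for the whole batch.
    prop = np.random.uniform()
    if prop < exploration_rate:
        # Explore: one random action index per sample; no Q-values are computed.
        action_batch = np.random.randint(0, len(config.actions), (state_batch.shape[0],))
        q = None
    else:
        # Exploit: evaluate the Q-function and take the best action index per sample.
        state_batch = Variable(state_batch)
        if config.use_gpu:
            state_batch.to_gpu()
        q = self.compute_q_variable(state_batch, test=True)
        if config.use_gpu:
            q.to_cpu()
        q = q.data
        action_batch = np.argmax(q, axis=1)
    # Translate each action index into the action value used by the caller.
    for i in range(action_batch.shape[0]):
        action_batch[i] = self.get_action_for_index(action_batch[i])
    return action_batch, q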
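
The selection logic in eps_greedy can be tried out without Chainer or the project's config. The following is a minimal sketch under stated assumptions, not the author's implementation: the function name eps_greedy_sketch and the rng parameter are hypothetical, and the Q-values are passed in directly instead of being computed by a network. It keeps the original's behavior of making one explore-or-exploit decision for the whole batch.

```python
import numpy as np


def eps_greedy_sketch(q_values, exploration_rate, rng):
    """Pick one action index per row of q_values (hypothetical helper).

    A single uniform draw decides whether the whole batch explores
    (random indices) or exploits (row-wise argmax), as in eps_greedy.
    """
    q_values = np.asarray(q_values, dtype=np.float32)
    if q_values.ndim == 1:
        # Treat a single Q-vector as a batch of one.
        q_values = q_values.reshape(1, -1)
    batch_size, n_actions = q_values.shape
    if rng.uniform() < exploration_rate:
        # Explore: random indices, and no Q-values are returned.
        return rng.integers(0, n_actions, size=batch_size), None
    # Exploit: best action index per row.
    return np.argmax(q_values, axis=1), q_values


rng = np.random.default_rng(0)
q = np.array([[0.1, 0.9, 0.3], [0.7, 0.2, 0.1]])
greedy_actions, q_greedy = eps_greedy_sketch(q, 0.0, rng)  # rate 0.0: never explores
random_actions, q_random = eps_greedy_sketch(q, 1.0, rng)  # rate 1.0: always explores
```

Because the draw comes from the half-open interval [0, 1), a rate of 0.0 always exploits and a rate of 1.0 always explores, which makes the two branches easy to check in isolation.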