model.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目: nesgym　作者: codescv（项目源码 | 文件源码）
def __init__(self,
             image_shape,
             num_actions,
             frame_history_len=4,
             replay_buffer_size=1000000,
             training_freq=4,
             training_starts=5000,
             training_batch_size=32,
             target_update_freq=1000,
             reward_decay=0.99,
             exploration=None,
             log_dir="logs/"):
        """Build a Double Deep Q Network agent.

        Args:
            image_shape: shape of a single observation frame,
                (height, width, n_values). Any sequence (tuple or list) is
                accepted; only the last axis is stacked across frames.
            num_actions: how many different actions the agent can choose.
            frame_history_len: number of frames whose data is stacked along
                the last axis and fed as one input to the Q networks.
            replay_buffer_size: size limit of the replay buffer.
            training_freq: train the base Q network once per
                ``training_freq`` steps.
            training_starts: only train the Q network after this number of
                steps.
            training_batch_size: batch size for training the base Q network
                with gradient descent.
            target_update_freq: stored for use elsewhere in the class;
                presumably the number of steps between copies of the base
                network's weights into the target network (no copy happens
                in this method -- confirm against the training loop).
            reward_decay: decay factor (called gamma in the paper) applied to
                rewards that happen in the future.
            exploration: schedule used to generate the exploration factor
                epsilon (see 'epsilon-greedy' in the paper): when
                rand(0, 1) < epsilon take a random action, otherwise take the
                greedy action. Defaults to a new ``LinearSchedule(5000, 0.1)``
                created for this instance.
            log_dir: path to write TensorBoard logs to.
        """
        super().__init__()
        self.num_actions = num_actions
        self.training_freq = training_freq
        self.training_starts = training_starts
        self.training_batch_size = training_batch_size
        self.target_update_freq = target_update_freq
        self.reward_decay = reward_decay

        # Build the default here rather than in the signature: a default
        # expression is evaluated only once, at definition time, so every
        # agent would otherwise share the same schedule object.
        if exploration is None:
            exploration = LinearSchedule(5000, 0.1)
        self.exploration = exploration

        # Use multiple frames as input to the Q networks: frames are stacked
        # along the last axis, so that axis grows by frame_history_len.
        # tuple() lets callers pass image_shape as a list too.
        image_shape = tuple(image_shape)
        input_shape = image_shape[:-1] + (image_shape[-1] * frame_history_len,)

        # Base (online) network, used to choose actions. It is compiled here
        # with Adam (gradient norm clipped to 10) and MSE loss so it can be trained.
        self.base_model = q_model(input_shape, num_actions)
        self.base_model.compile(optimizer=optimizers.adam(clipnorm=10, lr=1e-4, decay=1e-6, epsilon=1e-4), loss='mse')
        # Target network, used to estimate Q values; same architecture, not compiled here.
        # NOTE(review): its weights are not copied from base_model here, so until
        # the first sync it starts from an independent initialisation -- confirm
        # the training loop syncs it before the target is first used.
        self.target_model = q_model(input_shape, num_actions)

        self.replay_buffer = ReplayBuffer(size=replay_buffer_size, frame_history_len=frame_history_len)
        # Current replay buffer offset.
        self.replay_buffer_idx = 0

        self.tensorboard_callback = TensorBoard(log_dir=log_dir)
        # Rolling window holding at most the 100 most recent loss values.
        self.latest_losses = deque(maxlen=100)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号