dqn.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:pytorch.rl.learning 作者: moskomule 项目源码 文件源码
def __init__(self, agent: Agent, val_env: gym.Env, lr, memory_size, target_update_freq, gradient_update_freq,
                 batch_size, replay_start, val_freq, log_freq_by_step, log_freq_by_ep, val_epsilon,
                 log_dir, weight_dir):
        """
        Store the training configuration and build the optimizer, replay memory and summary writer.

        The training environment is taken from ``agent.env``; ``val_env`` is kept separately.

        :param agent: agent object; its ``env`` and ``net`` attributes are read here
        :param val_env: environment for validation
        :param lr: learning rate of the RMSprop optimizer built over ``agent.net.parameters()``
        :param memory_size: size of replay memory (passed to ``Memory``)
        :param target_update_freq: frequency of update target network in steps
        :param gradient_update_freq: frequency of q-network update in steps
        :param batch_size: batch size for q-net
        :param replay_start: number of random exploration before starting
        :param val_freq: frequency of validation in steps
        :param log_freq_by_step: frequency of logging in steps
        :param log_freq_by_ep: frequency of logging in episodes
        :param val_epsilon: exploration rate for validation
        :param log_dir: directory for saving tensorboard things; a timestamped sub-directory is used per run
        :param weight_dir: directory for saving weights when validated; created if missing,
            or ``None`` to skip directory creation
        """
        # Agent and environments.
        self.agent = agent
        self.env = self.agent.env
        self.val_env = val_env

        # Learning components.
        self.optimizer = optim.RMSprop(params=self.agent.net.parameters(), lr=lr)
        self.memory = Memory(memory_size)

        # Schedule / hyper-parameters.
        self.target_update_freq = target_update_freq
        self.batch_size = batch_size
        self.replay_start = replay_start
        self.gradient_update_freq = gradient_update_freq

        # Progress counters, starting from zero.
        self._step = 0
        self._episode = 0
        # NOTE(review): presumably set to True once the replay memory has been
        # pre-filled (see replay_start) -- confirm against the training loop.
        self._warmed = False

        # Validation and logging settings.
        self._val_freq = val_freq
        self.log_freq_by_step = log_freq_by_step
        self.log_freq_by_ep = log_freq_by_ep
        self._val_epsilon = val_epsilon

        # One sub-directory per run, named after the local start time (e.g. "Jan05_13-45-10").
        run_name = datetime.now().strftime('%b%d_%H-%M-%S')
        self._writer = SummaryWriter(os.path.join(log_dir, run_name))

        # exist_ok=True avoids the check-then-create race: a concurrent run
        # creating the same directory no longer triggers FileExistsError.
        if weight_dir is not None:
            os.makedirs(weight_dir, exist_ok=True)
        self.weight_dir = weight_dir
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号