def __init__(self, agent: Agent, val_env: gym.Env, lr, memory_size, target_update_freq, gradient_update_freq,
             batch_size, replay_start, val_freq, log_freq_by_step, log_freq_by_ep, val_epsilon,
             log_dir, weight_dir):
"""
:param agent: agent object
:param val_env: environment for validation
:param lr: learning rate of optimizer
:param memory_size: size of replay memory
:param target_update_freq: frequency of update target network in steps
:param gradient_update_freq: frequency of q-network update in steps
:param batch_size: batch size for q-net
:param replay_start: number of random exploration before starting
:param val_freq: frequency of validation in steps
:param log_freq_by_step: frequency of logging in steps
:param log_freq_by_ep: frequency of logging in episodes
:param val_epsilon: exploration rate for validation
:param log_dir: directory for saving tensorboard things
:param weight_dir: directory for saving weights when validated
"""
    self.agent = agent
    self.env = self.agent.env  # training environment comes from the agent
    self.val_env = val_env
    self.optimizer = optim.RMSprop(params=self.agent.net.parameters(), lr=lr)
    self.memory = Memory(memory_size)
    self.target_update_freq = target_update_freq
    self.batch_size = batch_size
    self.replay_start = replay_start
    self.gradient_update_freq = gradient_update_freq
    # internal counters: environment steps taken, episodes finished, and a flag
    # marking whether the initial random-exploration warm-up has been performed
    self._step = 0
    self._episode = 0
    self._warmed = False
    self._val_freq = val_freq
    self.log_freq_by_step = log_freq_by_step
    self.log_freq_by_ep = log_freq_by_ep
    self._val_epsilon = val_epsilon
    # TensorBoard logs go into a timestamped subdirectory of log_dir
    self._writer = SummaryWriter(os.path.join(log_dir, datetime.now().strftime('%b%d_%H-%M-%S')))
    # create the weight directory up front if one was given
    if weight_dir is not None and not os.path.exists(weight_dir):
        os.makedirs(weight_dir)
    self.weight_dir = weight_dir
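
# --- Usage sketch (illustrative, not from the original source) ---
# A minimal example of how this constructor might be called, assuming the
# enclosing class is named Trainer and that Agent(env) builds an agent exposing
# .net and .env; the class name, environment id, and every hyperparameter value
# below are assumptions chosen only to show the call shape.
#
#   import gym
#
#   train_env = gym.make('CartPole-v1')
#   val_env = gym.make('CartPole-v1')
#   agent = Agent(train_env)            # assumed constructor signature
#   trainer = Trainer(
#       agent=agent,
#       val_env=val_env,
#       lr=2.5e-4,
#       memory_size=100_000,
#       target_update_freq=1_000,
#       gradient_update_freq=4,
#       batch_size=32,
#       replay_start=10_000,
#       val_freq=5_000,
#       log_freq_by_step=100,
#       log_freq_by_ep=10,
#       val_epsilon=0.05,
#       log_dir='runs',
#       weight_dir='weights',
#   )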