def __init__(self, policy, target_policy, cmdl):
    """Set up the state used by the "Categorical-PI" update rule.

    Args:
        policy: Network whose ``parameters()`` are handed to the optimizer.
        target_policy: Second network; only stored on the instance here.
        cmdl: Configuration object. This constructor reads ``lr``,
            ``gamma``, ``training_steps``, ``cuda``, ``v_min``, ``v_max``,
            ``atoms_no`` and ``batch_size`` from it, and also passes the
            whole object to ``optim_factory``.
    """
    # Networks and general bookkeeping.
    self.policy = policy
    self.target_policy = target_policy
    self.name = "Categorical-PI"
    self.cmdl = cmdl
    self.lr = cmdl.lr
    self.gamma = cmdl.gamma

    # Optimizer over the online policy's parameters, starting with
    # cleared gradients.
    self.optimizer = optim_factory(self.policy.parameters(), cmdl)
    self.optimizer.zero_grad()

    # NOTE(review): presumably yields learning rates going from cmdl.lr
    # towards 0.00001 over cmdl.training_steps -- confirm against
    # lr_schedule's definition.
    self.lr_generator = lr_schedule(cmdl.lr, 0.00001, cmdl.training_steps)

    # Tensor type selector; only its ``FT`` member is used below.
    # NOTE(review): FT looks like the float tensor type for the chosen
    # device -- confirm in TorchTypes.
    tensor_types = TorchTypes(cmdl.cuda)
    self.dtype = tensor_types

    # Bounds and size of the fixed support.
    lower = cmdl.v_min
    upper = cmdl.v_max
    bins = cmdl.atoms_no
    self.v_min = lower
    self.v_max = upper
    self.atoms_no = bins

    # ``bins`` evenly spaced values covering [lower, upper], cast to FT.
    self.support = torch.linspace(lower, upper, bins).type(tensor_types.FT)

    # Distance between two neighbouring support values.
    self.delta_z = (upper - lower) / (bins - 1)

    # Zero-filled (batch_size, atoms_no) tensor.
    # NOTE(review): presumably a reusable buffer for the update step --
    # verify where self.m is consumed.
    self.m = torch.zeros(cmdl.batch_size, bins).type(tensor_types.FT)
categorical_update.py 文件源码
python
阅读 40
收藏 0
点赞 0
评论 0
评论列表
文章目录