def __init__(self, env, is_batch_norm):
    """Create the actor/critic networks and bookkeeping state for ``env``.

    Args:
        env: Environment object. This method reads
            ``env.observation_space.shape[0]``, ``env.action_space.shape[0]``,
            ``env.action_space.high`` and ``env.action_space.low``.
            NOTE(review): presumably a Gym-style env with Box spaces; confirm.
        is_batch_norm: When truthy, the ``CriticNet_bn`` / ``ActorNet_bn``
            variants are built; otherwise ``CriticNet`` / ``ActorNet``.
    """
    self.env = env
    self.num_states = env.observation_space.shape[0]
    self.num_actions = env.action_space.shape[0]

    # Pick the network classes first. Only the chosen pair is evaluated,
    # just as in a plain if/else.
    critic_cls, actor_cls = (
        (CriticNet_bn, ActorNet_bn) if is_batch_norm else (CriticNet, ActorNet)
    )
    net_dims = (self.num_states, self.num_actions)
    # The critic is built before the actor. The order is kept deliberately
    # in case the network constructors depend on it.
    self.critic_net = critic_cls(*net_dims)
    self.actor_net = actor_cls(*net_dims)

    # Replay memory starts empty. No maxlen is set here; any size limit
    # would have to be enforced by other methods.
    self.replay_memory = deque()

    # Step counters, both starting at zero.
    self.time_step = 0
    self.counter = 0

    # The gradient inverter is given [upper_bounds, lower_bounds], each
    # converted to plain Python lists.
    upper_bounds = np.array(env.action_space.high).tolist()
    lower_bounds = np.array(env.action_space.low).tolist()
    self.grad_inv = grad_inverter([upper_bounds, lower_bounds])
评论列表
文章目录