def __init__(self, worker_idx, params, net, session, eval_var, worker_summary_dict):
    """Set up one worker: its environment, its handles into the shared graph, and logging.

    Args:
        worker_idx: Index of this worker. It selects this worker's entries from
            the per-worker lists in ``net`` ('train_ops', 'worker_nets', 'copy_ops').
        params: Dict-like configuration. Keys read here: 'rom', 'net_type',
            'show_0th_thread'; when net_type is 'AnDQN' also 'eps_prob',
            'eps_max', 'eps_min' and 'eps_frame'.
        net: Dict-like collection of graph handles built elsewhere.
        session: Stored as ``self.sess``.
        eval_var: Stored unchanged as ``self.eval_var``.
        worker_summary_dict: Must provide an 'op' and a 'writer' entry.
    """
    self.dead = False
    self.params = params
    self.idx = worker_idx
    cfg = self.params
    i = self.idx

    # Environment: 'toy_way' selects the toy env, any other rom uses env_atari.
    # The env is reset immediately so the first observation is available.
    if cfg['rom'] == 'toy_way':
        self.env = env_way.env_way(cfg)
    else:
        self.env = env_atari.env_atari(cfg)
    self.img = self.env.reset()

    # Handles into the prebuilt graph: this worker's own ops and nets,
    # plus the shared master net and the global-frame counter plumbing.
    self.train = net['train_ops'][i]
    self.net = net['worker_nets'][i]
    self.sess = session
    self.worker_copy = net['copy_ops'][i]
    self.master = net['master_net']
    self.global_frame = net['global_frame']
    self.frame_ph = net['global_frame_ph']
    self.gf_op = net['global_frame_op']
    self.lr_ph = net['lr_ph']

    # Summary logging.
    self.summary_op = worker_summary_dict['op']
    self.summary_writer = worker_summary_dict['writer']
    self.eval_var = eval_var

    if cfg['net_type'] != 'AnDQN':
        # In A3C the "target" network is simply the local worker network,
        # so code can be shared with the DQN path.
        self.target = net['worker_nets'][i]
    else:
        self.target = net['target_net']
        # Sample one epsilon profile index, weighted by params['eps_prob'],
        # and take this worker's eps_max / eps_min / eps_frame from that profile.
        weights = cfg['eps_prob']
        profile = np.random.choice(len(weights), size=1, p=np.asarray(weights))[0]
        self.eps_max = cfg['eps_max'][profile]
        self.eps_min = cfg['eps_min'][profile]
        self.eps_frame = cfg['eps_frame'][profile]

    # Optionally open an OpenCV display window, for worker 0 only.
    if i == 0 and cfg['show_0th_thread']:
        cv2.startWindowThread()
        cv2.namedWindow('Worker' + str(i) + '_screen')
评论列表
文章目录