worker.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:A3C_tensorflow 作者: gliese581gg 项目源码 文件源码
def __init__(self,worker_idx,params,net,session,eval_var,worker_summary_dict):
        """Set up one worker: its environment and its handles into the shared graph.

        Args:
            worker_idx: Index of this worker. Selects this worker's entries in
                net['train_ops'], net['worker_nets'] and net['copy_ops'];
                worker 0 may also open a preview window (see Side effects).
            params: Hyperparameter dict. Keys read here: 'rom', 'net_type',
                'show_0th_thread' and, when net_type == 'AnDQN', 'eps_prob',
                'eps_max', 'eps_min' and 'eps_frame'. The whole dict is also
                passed to the environment constructor.
            net: Dict of network components built elsewhere. Keys read:
                'train_ops', 'worker_nets', 'copy_ops', 'master_net',
                'global_frame', 'global_frame_ph', 'global_frame_op',
                'lr_ph' and, when net_type == 'AnDQN', 'target_net'.
            session: Session used to run the graph; stored as ``self.sess``
                (presumably a TensorFlow session -- confirm against caller).
            eval_var: Stored unchanged as ``self.eval_var``.
            worker_summary_dict: Dict providing the summary op ('op') and the
                summary writer ('writer') for this worker.

        Raises:
            KeyError: if any of the keys listed above is missing.

        Side effects:
            Builds the environment (``env_way`` when params['rom'] is
            'toy_way', otherwise ``env_atari``) and calls its ``reset()``;
            the returned observation is kept in ``self.img``. When
            worker_idx == 0 and params['show_0th_thread'] is truthy, starts
            the OpenCV window thread and creates the window 'Worker0_screen'.
        """
        self.dead = False
        self.params = params
        self.idx = worker_idx

        # environment
        if self.params['rom'] == 'toy_way':
            self.env = env_way.env_way(self.params)
        else:
            self.env = env_atari.env_atari(self.params)
        self.img = self.env.reset()

        # build networks: take this worker's entry from the per-worker lists,
        # plus the handles that are shared by all workers
        self.train = net['train_ops'][self.idx]
        self.net = net['worker_nets'][self.idx]
        self.sess = session
        self.worker_copy = net['copy_ops'][self.idx]
        self.master = net['master_net']
        self.global_frame = net['global_frame']
        self.frame_ph = net['global_frame_ph']
        self.gf_op = net['global_frame_op']
        self.lr_ph = net['lr_ph']
        self.summary_op = worker_summary_dict['op']
        self.summary_writer = worker_summary_dict['writer']
        self.eval_var = eval_var

        if self.params['net_type'] == 'AnDQN':
            self.target = net['target_net']
            # Each worker draws its own schedule index, weighted by
            # params['eps_prob'], and reads the matching eps_max / eps_min /
            # eps_frame entries (presumably an epsilon-greedy exploration
            # schedule -- confirm against the acting code). Passing the
            # population size directly returns the index itself, instead of
            # building np.arange(...) and unwrapping a size-1 array.
            eps_type = np.random.choice(len(self.params['eps_prob']),
                                        p=self.params['eps_prob'])
            self.eps_max = self.params['eps_max'][eps_type]
            self.eps_min = self.params['eps_min'][eps_type]
            self.eps_frame = self.params['eps_frame'][eps_type]
        else:
            # In A3C the "target" network is simply the worker's local
            # network, so A3C can share code paths with DQN.
            self.target = net['worker_nets'][self.idx]

        if self.idx == 0 and self.params['show_0th_thread']:
            cv2.startWindowThread()
            cv2.namedWindow('Worker' + str(self.idx) + '_screen')
评论列表
文章目录


问题


面经


文章

微信公众号

扫码关注公众号