shared_utils.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:async-deep-rl 作者: traai 项目源码 文件源码
def __init__(self, num_actions, alg_type, opt_type=None, lr=0):
    """Record the network's variable shapes and allocate optimizer state.

    Args:
        num_actions: Size of the last dimension of the final weight
            variable, and of its matching 1-D variable.
        alg_type: ``'q'`` or ``'sarsa'`` selects the eight-variable
            layout; any other value appends two more variables,
            ``(256, 1)`` and ``1``.
        opt_type: ``"adam"`` allocates ``self.ms`` and ``self.vs`` and
            stores the learning rate in ``self.lr`` as a ``RawValue``;
            ``"rmsprop"`` allocates ``self.vars`` initialised from an
            array of ones; every other value (including ``"momentum"``
            and ``None``) allocates ``self.vars`` with no initial value.
        lr: Learning rate; only stored when ``opt_type`` is ``"adam"``.

    Sets ``self.var_shapes`` (list of shapes), ``self.size`` (total
    number of scalars across all shapes) and the optimizer buffers
    described above.
    """
    # NOTE: the 1-D entries are bare ints, not 1-tuples -- that is what the
    # original ``(16)``-style literals evaluated to, so the stored values
    # are unchanged for existing consumers of ``var_shapes``.
    self.var_shapes = [
        (8, 8, 4, 16),
        16,
        (4, 4, 16, 32),
        32,
        (2592, 256),  # (3872, 256) if PADDING = "SAME"
        256,
        (256, num_actions),
        num_actions,
    ]
    if alg_type not in ('q', 'sarsa'):
        # no lstm; two extra variables for the non-q/sarsa algorithms
        # (presumably a scalar value output -- confirm against the model).
        self.var_shapes += [(256, 1), 1]

    # Keep the total as a plain Python int: np.prod returns a NumPy scalar,
    # which stdlib size checks such as multiprocessing's RawArray
    # (isinstance(size, int)) reject on Python 3.
    # NOTE(review): assumes malloc_contiguous wraps such an allocation.
    self.size = sum(int(np.prod(shape)) for shape in self.var_shapes)

    if opt_type == "adam":
        self.ms = self.malloc_contiguous(self.size)
        self.vs = self.malloc_contiguous(self.size)
        self.lr = RawValue(ctypes.c_float, lr)
    # elif (not a second `if`): adam must not fall through and also
    # allocate an unused self.vars buffer.
    elif opt_type == "rmsprop":
        # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
        ones = np.ones(self.size, dtype=np.float64)
        self.vars = self.malloc_contiguous(self.size, ones)
    else:
        # "momentum" and any unrecognised/None optimizer type.
        self.vars = self.malloc_contiguous(self.size)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号