shared_memory.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:tensorflow-rl 作者: steveKapturowski 项目源码 文件源码
def __init__(self, params, opt_type=None, lr=0, step=0):
        self.var_shapes = [
            var.get_shape().as_list()
            for var in params]
        self.size = sum([np.prod(shape) for shape in self.var_shapes])
        self.step = RawValue(ctypes.c_int, step)

        if opt_type == 'adam':
            self.ms = self.malloc_contiguous(self.size)
            self.vs = self.malloc_contiguous(self.size)
            self.lr = RawValue(ctypes.c_float, lr)
        elif opt_type == 'adamax':
            self.ms = self.malloc_contiguous(self.size)
            self.vs = self.malloc_contiguous(self.size)
            self.lr = RawValue(ctypes.c_float, lr)
        elif opt_type == 'rmsprop':
            self.vars = self.malloc_contiguous(self.size, np.ones(self.size, dtype=np.float))
        elif opt_type == 'momentum':
            self.vars = self.malloc_contiguous(self.size)
        else:
            self.vars = self.malloc_contiguous(self.size)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号