def __init__(self, num_actions, alg_type, opt_type=None, lr=0):
    """Size and allocate the flat buffers holding one copy of the network variables.

    The variable shapes are recorded in ``self.var_shapes`` in a fixed order.
    Their total element count goes in ``self.size``. Buffers of that length
    are then obtained from ``self.malloc_contiguous`` according to ``opt_type``.

    Args:
        num_actions: Number of actions; sets the width of the last weight
            matrix and bias.
        alg_type: 'q' or 'sarsa' reserves only the layers listed below. Any
            other value also reserves a (256, 1) weight and a (1,) bias
            (presumably a state-value head -- confirm against the network).
        opt_type: "adam", "rmsprop", "momentum" or None.
            - "adam" also allocates ``self.ms``, ``self.vs`` and ``self.lr``.
            - "rmsprop" passes an all-ones array when allocating ``self.vars``.
            - Anything else allocates ``self.vars`` with no explicit initial
              value.
        lr: Stored as a shared ``RawValue(ctypes.c_float, lr)``; only used
            when ``opt_type == "adam"``.
    """
    # Keep this order stable: callers may slice the flat buffer per variable.
    # NOTE(review): the shapes look like (kernel, bias) pairs for two conv
    # layers, one dense layer and the action-output layer.
    self.var_shapes = [
        (8, 8, 4, 16),
        (16,),  # written as 1-tuples; `(16)` would be the plain int 16
        (4, 4, 16, 32),
        (32,),
        # 2592 = 9 * 9 * 32, presumably the flattened second conv output for
        # an 84x84 input; use (3872, 256) if PADDING = "SAME".
        (2592, 256),
        (256,),
        (256, num_actions),
        (num_actions,),
    ]
    if alg_type not in ('q', 'sarsa'):
        # Any alg_type other than 'q'/'sarsa' ("no lstm" variant) also
        # reserves a (256, 1) weight and a (1,) bias.
        self.var_shapes += [(256, 1), (1,)]

    # np.prod returns a NumPy integer. int() keeps self.size a plain Python
    # int in case malloc_contiguous hands it to an API that only accepts real
    # ints as a length (e.g. multiprocessing RawArray on Python 3).
    self.size = sum(int(np.prod(shape)) for shape in self.var_shapes)

    if opt_type == "adam":
        # Two extra size-length buffers plus a shared float learning rate.
        self.ms = self.malloc_contiguous(self.size)
        self.vs = self.malloc_contiguous(self.size)
        self.lr = RawValue(ctypes.c_float, lr)

    # Every opt_type, including "adam", gets self.vars. Previously the
    # q/sarsa + adam combination left the attribute undefined.
    if opt_type == "rmsprop":
        # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
        self.vars = self.malloc_contiguous(
            self.size, np.ones(self.size, dtype=np.float64))
    else:
        # "momentum", None and any other value.
        self.vars = self.malloc_contiguous(self.size)
评论列表
文章目录