def __init__(self, params, opt_type=None, lr=0, step=0):
    """Allocate flat optimizer-state buffers sized to cover all ``params``.

    Args:
        params: Iterable of variables exposing ``get_shape().as_list()``
            (TensorFlow-style API; shapes are assumed fully defined —
            a ``None`` dimension would make the product fail).
        opt_type: ``'adam'`` or ``'adamax'`` allocate ``ms``/``vs`` buffers
            and a shared ``lr``; ``'rmsprop'`` allocates ``vars`` initialised
            to ones; ``'momentum'`` and any other value (including ``None``)
            allocate a plain ``vars`` buffer.
        lr: Initial learning rate, stored only for Adam/Adamax.
        step: Initial value of the shared step counter.

    Sets ``var_shapes`` (one shape list per param), ``size`` (total element
    count as a Python ``int``) and ``step``, plus the ``opt_type``-specific
    buffers listed above, each obtained from ``self.malloc_contiguous``.
    """
    self.var_shapes = [var.get_shape().as_list() for var in params]
    # Use an int64 product: np.prod([]) on a scalar variable's empty shape
    # returns the float 1.0, which would turn the total size into a float.
    self.size = sum(
        int(np.prod(shape, dtype=np.int64)) for shape in self.var_shapes)
    # RawValue presumably comes from multiprocessing.sharedctypes (shared
    # ctypes value with no lock) — confirm against the file's imports.
    self.step = RawValue(ctypes.c_int, step)
    if opt_type in ('adam', 'adamax'):
        # Both optimizers use the same two state buffers (presumably first and
        # second moment estimates) plus a shared learning rate.
        self.ms = self.malloc_contiguous(self.size)
        self.vs = self.malloc_contiguous(self.size)
        self.lr = RawValue(ctypes.c_float, lr)
    elif opt_type == 'rmsprop':
        # State buffer initialised to ones. np.float (an alias of the builtin
        # float, i.e. float64) was removed in NumPy 1.24; use np.float64.
        self.vars = self.malloc_contiguous(
            self.size, np.ones(self.size, dtype=np.float64))
    else:
        # 'momentum' and any unrecognised / None opt_type share this path.
        self.vars = self.malloc_contiguous(self.size)
评论列表
文章目录