def __init__(self, var_list, *, beta1=0.9, beta2=0.999, epsilon=1e-08, scale_grad_by_procs=True, comm=None):
    """Set up optimizer state for the variables in ``var_list``.

    Args:
        var_list: Iterable of variables. It is stored on the instance,
            walked once to total ``U.numel`` over its members, and handed
            to ``U.SetFromFlat`` / ``U.GetFlat``.
        beta1: Stored as ``self.beta1``; not used in this method.
        beta2: Stored as ``self.beta2``; not used in this method.
        epsilon: Stored as ``self.epsilon``; not used in this method.
        scale_grad_by_procs: Stored as ``self.scale_grad_by_procs``; not
            used in this method.
        comm: Communicator object to keep on the instance. When ``None``,
            ``MPI.COMM_WORLD`` is used instead.
    """
    # NOTE(review): beta1/beta2/epsilon look like Adam hyperparameters --
    # confirm against the update method, which is not visible here.
    self.var_list = var_list
    self.beta1, self.beta2, self.epsilon = beta1, beta2, epsilon
    self.scale_grad_by_procs = scale_grad_by_procs

    # Total element count across every variable, as reported by U.numel.
    total_elements = 0
    for variable in var_list:
        total_elements += U.numel(variable)

    # Two flat float32 buffers, one slot per counted element, zero-filled.
    # NOTE(review): presumably first/second moment estimates -- confirm.
    self.m = np.zeros(total_elements, dtype='float32')
    self.v = np.zeros(total_elements, dtype='float32')
    self.t = 0  # counter, starts at zero

    # Helpers built from the same variable list (semantics live in U).
    self.setfromflat = U.SetFromFlat(var_list)
    self.getflat = U.GetFlat(var_list)

    # MPI is only touched when the caller did not supply a communicator.
    if comm is None:
        comm = MPI.COMM_WORLD
    self.comm = comm
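# Web-page residue from the scraped source, not part of the code: "评论列表" (comment list), "文章目录" (article table of contents).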