def init_state(self, param, state): xp = cuda.get_array_module(param.data) with cuda.get_device(param.data): state['m'] = xp.zeros_like(param.data) state['v'] = xp.zeros_like(param.data)