def init_state(self, param, state):
    """Fill ``state`` with zero-initialised buffers matching ``param.data``.

    Two independent arrays are created with ``zeros_like`` on the same array
    module (as chosen by ``cuda.get_array_module``) and inside the device
    context returned by ``cuda.get_device`` for the parameter's data. They
    are stored under the keys ``'msg'`` and ``'msdx'``, in that order.

    NOTE(review): the key names suggest running mean-square accumulators of
    the gradient and of the update (AdaDelta-style) — confirm against the
    update rule that consumes them.

    Args:
        param: Object exposing a ``data`` array; only ``param.data`` is read.
        state: Mutable mapping that receives the new buffers.
    """
    arr = param.data
    module = cuda.get_array_module(arr)
    # Allocate inside the data's device context so the buffers live on the
    # same device as the parameter they accompany.
    with cuda.get_device(arr):
        for key in ('msg', 'msdx'):
            state[key] = module.zeros_like(arr)