def init_state(self, param, state):
    """Allocate the per-parameter optimizer state buffers.

    Creates three zero-filled arrays, ``state['n']``, ``state['g']`` and
    ``state['delta']``. Each comes from ``xp.zeros_like(param.data)``, so it
    matches the shape and dtype of ``param.data``. ``xp`` is the array module
    that ``cuda.get_array_module`` returns for ``param.data`` (presumably NumPy
    for CPU arrays and CuPy for GPU arrays; confirm against the Chainer
    version in use).

    Args:
        param: Parameter object. Only its ``data`` array is read here.
        state: Mutable mapping that receives the three new entries. Any
            existing values under these keys are overwritten.
    """
    xp = cuda.get_array_module(param.data)
    # Allocate inside the device context derived from param.data, presumably
    # so GPU buffers live on the same device as the parameter.
    # NOTE(review): newer Chainer releases deprecate passing an array to
    # cuda.get_device in favour of cuda.get_device_from_array; kept as-is to
    # avoid depending on a particular Chainer version.
    with cuda.get_device(param.data):
        # NOTE(review): the n / g / delta trio looks like RMSpropGraves state
        # (running mean of squared gradients, running mean of gradients, and
        # the momentum/update term). Confirm against this optimizer's update
        # rule.
        state['n'] = xp.zeros_like(param.data)
        state['g'] = xp.zeros_like(param.data)
        state['delta'] = xp.zeros_like(param.data)
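# Web-page residue from the source this snippet was copied from:
# "评论列表" (comment list), "文章目录" (table of contents).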