def update_one_gpu(self, param, state):
    """Apply one plain SGD step to ``param`` on the GPU.

    Builds an elementwise kernel via ``cuda.elementwise`` that evaluates
    ``param -= lr * grad`` element by element. It then runs the kernel with
    ``param.grad`` and ``self.lr`` as inputs and ``param.data`` as the
    output array.

    Args:
        param: Object exposing ``grad`` (gradient array) and ``data``
            (parameter array written by the kernel).
        state: Per-parameter optimizer state; not used by this update.
    """
    # All three operands share the type placeholder ``T``. The scalar lr is
    # presumably cast to the arrays' dtype by the kernel machinery; verify
    # against the cuda.elementwise implementation.
    sgd_kernel = cuda.elementwise(
        'T grad, T lr',        # inputs: gradient array, scalar learning rate
        'T param',             # output: parameter array, updated in place
        'param -= lr * grad',  # per-element SGD update
        'sgd',                 # kernel name
    )
    sgd_kernel(param.grad, self.lr, param.data)