def _test_rosenbrock(self, constructor, old_fn):
    """Check an optimizer against a reference update rule on Rosenbrock.

    Starting both from ``[1.5, 1.5]``, runs 2000 iterations in lockstep:
    one ``step`` of the optimizer built by ``constructor`` and one call of
    ``old_fn``. After every iteration the two parameter tensors are
    asserted equal; after the loop the optimized parameters must be no
    farther from ``solution`` (``[1, 1]``) than they were initially.

    Args:
        constructor: Callable that receives a list containing the single
            parameter and returns an object exposing ``zero_grad()`` and
            ``step(closure)``.
        old_fn: Reference update, called as
            ``old_fn(opfunc, params_t, state)`` where ``opfunc`` returns
            the pair ``(rosenbrock(params_t), drosenbrock(params_t))``.
            Presumably updates ``params_t`` in place and keeps its own
            bookkeeping in ``state`` -- confirm against the callers.
    """
    params_t = torch.Tensor([1.5, 1.5])
    state = {}

    params = Variable(torch.Tensor([1.5, 1.5]), requires_grad=True)
    optimizer = constructor([params])

    solution = torch.Tensor([1, 1])
    initial_dist = params.data.dist(solution)

    def closure():
        # Gradients are cleared on every evaluation, not once per step:
        # an optimizer may invoke the closure several times within a single
        # step(), and backward() accumulates into .grad. Without this the
        # autograd gradient would stop matching the analytic one that the
        # reference receives from drosenbrock.
        optimizer.zero_grad()
        loss = rosenbrock(params)
        loss.backward()
        return loss

    for _ in range(2000):
        optimizer.step(closure)
        # The reference ignores the argument passed to its opfunc and is
        # always evaluated at the current params_t.
        old_fn(lambda _: (rosenbrock(params_t), drosenbrock(params_t)),
               params_t, state)
        # Both implementations must agree after every single update.
        self.assertEqual(params.data, params_t)

    self.assertLessEqual(params.data.dist(solution), initial_dist)
# Comment list
# Table of contents