torch_utils.py source code

python

Project: pytorch-trpo    Author: mjacar

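# Imports this method needs
import numpy as np
import torch
import torch.optim as optim
from torch.nn.utils import parameters_to_vector, vector_to_parameters

def fit(self, observations, labels):
    """Fit self.model with L-BFGS, halving the learning rate after each divergent attempt."""
    def closure():
        # L-BFGS may re-evaluate the loss several times per step.
        self.optimizer.zero_grad()
        predicted = self.predict(observations)
        loss = self.loss_fn(predicted, labels)
        loss.backward()
        return loss

    # Snapshot of the current weights, used to undo a divergent step.
    old_params = parameters_to_vector(self.model.parameters()).detach()
    # Try lr, lr/2, ..., lr/2**9.
    for lr in self.lr * 0.5 ** np.arange(10):
        self.optimizer = optim.LBFGS(self.model.parameters(), lr=float(lr))
        self.optimizer.step(closure)
        current_params = parameters_to_vector(self.model.parameters())
        # Accept the step only if no weight is NaN or inf.
        if torch.isfinite(current_params).all():
            return
        print("LBFGS optimization diverged. Rolling back update...")
        # Restore from a copy: vector_to_parameters makes the parameters views of
        # the vector it is given, so later in-place updates would corrupt old_params.
        vector_to_parameters(old_params.clone(), self.model.parameters())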
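One subtlety in the rollback step: in recent PyTorch releases, `vector_to_parameters` assigns each parameter's `.data` to a slice of the vector it receives, so the parameters become views of that vector, and any later in-place update (which is how `optim.LBFGS` changes weights) also writes into it. If the snapshot itself is restored and a later attempt diverges again, the "backup" already holds the diverged values. The sketch below checks this under that assumption, using a hypothetical `nn.Linear(3, 1)` model; `snapshot_survives` is an illustrative helper, not part of pytorch-trpo, and the NaN added in place stands in for a divergent optimizer step.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector, vector_to_parameters


def snapshot_survives(clone_on_restore: bool) -> bool:
    """Hypothetical helper: does a weight snapshot survive a restore followed
    by an in-place update that produces NaN?"""
    torch.manual_seed(0)
    model = nn.Linear(3, 1)  # illustrative model, not from pytorch-trpo

    # Snapshot taken the same way fit() takes it.
    snapshot = parameters_to_vector(model.parameters()).detach()
    reference = snapshot.clone()  # independent copy to compare against later

    # Roll back, either from the snapshot itself or from a copy of it.
    source = snapshot.clone() if clone_on_restore else snapshot
    vector_to_parameters(source, model.parameters())

    # Stand-in for a divergent optimizer step: modify the weights in place.
    with torch.no_grad():
        for p in model.parameters():
            p.add_(float("nan"))

    # torch.equal returns False if any entry of the snapshot became NaN.
    return torch.equal(snapshot, reference)


print(snapshot_survives(clone_on_restore=False))  # False: snapshot corrupted
print(snapshot_survives(clone_on_restore=True))   # True: snapshot intact
```

Restoring from a clone costs one parameter-sized allocation per rollback, which is negligible next to an L-BFGS step, and it keeps the snapshot valid across any number of divergent attempts.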