def fit_batch(self, premise_batch, hypothesis_batch, y_batch):
    """Run one optimisation step on a single batch and report its metrics.

    The loss criterion (``nn.NLLLoss``) and the Adam optimizer are created
    lazily on the first call and reused on every later call. As a result,
    changes to ``self.options['LR']`` or ``self.options['L2']`` made after
    the first call have no effect.

    Args:
        premise_batch: Premise input for the batch, passed unchanged to the
            model's forward call.
        hypothesis_batch: Hypothesis input for the batch, passed unchanged to
            the model's forward call.
        y_batch: Target class indices, used as the ``NLLLoss`` target.

    Returns:
        A ``(loss, accuracy)`` tuple:
        - ``loss``: the batch loss as a scalar, taken from the loss converted
          to a NumPy array.
        - ``accuracy``: the ``accuracy_score`` of the arg-max predictions
          against ``y_batch``.
    """
    if not hasattr(self, 'criterion'):
        # NLLLoss expects log-probabilities as input; presumably the forward
        # pass ends in log_softmax -- TODO confirm.
        self.criterion = nn.NLLLoss()
    if not hasattr(self, 'optimizer'):
        self.optimizer = optim.Adam(
            self.parameters(),
            lr=self.options['LR'],
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=self.options['L2'],  # 'L2' option applied as Adam weight decay
        )

    self.optimizer.zero_grad()
    preds = self(premise_batch, hypothesis_batch, training=True)
    loss = self.criterion(preds, y_batch)
    loss.backward()
    self.optimizer.step()

    # Predicted class = arg-max over the last dimension. Without keepdim the
    # labels stay 1-D, presumably matching y_batch's shape (NLLLoss with a
    # (N, C) input takes an (N,) target).
    _, pred_labels = torch.max(preds, dim=-1)
    y_true = self._get_numpy_array_from_variable(y_batch)
    y_pred = self._get_numpy_array_from_variable(pred_labels)
    acc = accuracy_score(y_true, y_pred)

    # From PyTorch 0.4 a reduced loss is 0-dimensional, so its NumPy form is a
    # 0-d array and indexing it with [0] raises IndexError. Flattening first
    # handles both 0-d and the older 1-element shape.
    # (assumes the helper returns a numpy.ndarray -- TODO confirm)
    ret_loss = self._get_numpy_array_from_variable(loss).reshape(-1)[0]
    return ret_loss, acc
rte_model.py 文件源码
python
阅读 20
收藏 0
点赞 0
评论 0
评论列表
文章目录