rte_model.py source code

Project: Recognizing-Textual-Entailment    Author: codedecde

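# Module-level imports this method relies on (not shown in the excerpt).
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score

def fit_batch(self, premise_batch, hypothesis_batch, y_batch):
    """Run one optimization step on a batch and return (loss, accuracy)."""
    # Lazily create the loss and optimizer on the first call.
    # NLLLoss expects the model to output log-probabilities (log_softmax).
    if not hasattr(self, 'criterion'):
        self.criterion = nn.NLLLoss()
    if not hasattr(self, 'optimizer'):
        self.optimizer = optim.Adam(self.parameters(), lr=self.options['LR'],
                                    betas=(0.9, 0.999), eps=1e-08,
                                    weight_decay=self.options['L2'])

    # Forward pass, loss, backward pass, parameter update.
    self.optimizer.zero_grad()
    preds = self(premise_batch, hypothesis_batch, training=True)
    loss = self.criterion(preds, y_batch)
    loss.backward()
    self.optimizer.step()

    # Batch accuracy: the predicted label is the argmax over the class dimension.
    _, pred_labels = torch.max(preds, dim=-1)
    y_true = self._get_numpy_array_from_variable(y_batch)
    y_pred = self._get_numpy_array_from_variable(pred_labels)
    acc = accuracy_score(y_true, y_pred)

    # In PyTorch >= 0.4 the loss is a 0-dim tensor, so indexing it with [0]
    # raises an error; .item() returns it as a Python float.
    ret_loss = loss.item()
    return ret_loss, acc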
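To show how fit_batch plugs into a model, here is a minimal sketch built around a hypothetical ToyRTEModel (mean-pooled word embeddings for premise and hypothesis feeding a linear 3-way classifier). This is not the project's actual architecture, and `_get_numpy_array_from_variable` below is a hypothetical stand-in for the project's helper, which this file does not show. The sketch assumes PyTorch 0.4 or later and computes accuracy with NumPy, which gives the same value as scikit-learn's accuracy_score for 1-D label arrays. The forward pass returns log_softmax because NLLLoss expects log-probabilities. The loop repeatedly trains on one fixed random batch, so the loss should fall across calls.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class ToyRTEModel(nn.Module):
    """Hypothetical toy entailment model (not the project's architecture)."""

    def __init__(self, vocab_size=50, dim=16, n_classes=3, options=None):
        super().__init__()
        self.options = options or {'LR': 1e-2, 'L2': 0.0}
        self.embed = nn.Embedding(vocab_size, dim)
        self.classifier = nn.Linear(2 * dim, n_classes)

    def forward(self, premise, hypothesis, training=False):
        # `training` is unused; it only matches the call made in fit_batch.
        p = self.embed(premise).mean(dim=1)      # (batch, dim)
        h = self.embed(hypothesis).mean(dim=1)   # (batch, dim)
        logits = self.classifier(torch.cat([p, h], dim=-1))
        # NLLLoss expects log-probabilities, so return log_softmax, not raw logits.
        return F.log_softmax(logits, dim=-1)

    def _get_numpy_array_from_variable(self, t):
        # Hypothetical helper: detach from the graph and move to CPU first.
        return t.detach().cpu().numpy()

    def fit_batch(self, premise_batch, hypothesis_batch, y_batch):
        # Lazily create the criterion and optimizer, as in the original method.
        if not hasattr(self, 'criterion'):
            self.criterion = nn.NLLLoss()
        if not hasattr(self, 'optimizer'):
            self.optimizer = optim.Adam(self.parameters(), lr=self.options['LR'],
                                        betas=(0.9, 0.999), eps=1e-08,
                                        weight_decay=self.options['L2'])
        self.optimizer.zero_grad()
        preds = self(premise_batch, hypothesis_batch, training=True)
        loss = self.criterion(preds, y_batch)
        loss.backward()
        self.optimizer.step()
        _, pred_labels = torch.max(preds, dim=-1)
        y_true = self._get_numpy_array_from_variable(y_batch)
        y_pred = self._get_numpy_array_from_variable(pred_labels)
        # Same value as sklearn's accuracy_score for 1-D label arrays.
        acc = float(np.mean(y_true == y_pred))
        return loss.item(), acc


# Usage: overfit a single random batch of 8 premise/hypothesis pairs.
torch.manual_seed(0)
model = ToyRTEModel()
premise = torch.randint(0, 50, (8, 5))      # 8 premises, 5 token ids each
hypothesis = torch.randint(0, 50, (8, 4))   # 8 hypotheses, 4 token ids each
labels = torch.randint(0, 3, (8,))          # 3 classes, e.g. entail/neutral/contradict
history = [model.fit_batch(premise, hypothesis, labels) for _ in range(50)]
print("first loss:", history[0][0], "last loss:", history[-1][0])
```

Because the criterion and optimizer are created lazily on the first call, changing `options['LR']` or `options['L2']` after training has started has no effect unless the `optimizer` attribute is deleted or rebuilt.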