sick_trainer.py source code


Project: MP-CNN-Variants  Author: tuzhucheng

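import torch.nn.functional as F


def train_epoch(self, epoch):
    """Train for one epoch and return the summed KL-divergence loss."""
    self.model.train()
    total_loss = 0.0
    for batch_idx, batch in enumerate(self.train_loader):
        self.optimizer.zero_grad()
        # F.kl_div expects the model output as log-probabilities
        output = self.model(batch.sentence_1, batch.sentence_2, batch.ext_feats)
        # batch.label is the target probability distribution; 'batchmean'
        # divides by the batch size only, whereas the old default also
        # averaged over the number of classes
        loss = F.kl_div(output, batch.label, reduction='batchmean')
        total_loss += loss.item()  # .item() replaces the removed loss.data[0]
        loss.backward()
        self.optimizer.step()
        if batch_idx % self.log_interval == 0:
            num_examples = len(batch.dataset.examples)
            self.logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, min(batch_idx * self.batch_size, num_examples),
                num_examples,
                100. * batch_idx / len(self.train_loader), loss.item())
            )

    if self.use_tensorboard:
        self.writer.add_scalar('sick/train/kl_div_loss', total_loss, epoch)

    return total_loss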
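To make the loss line concrete, here is a minimal pure-Python sketch of the quantity F.kl_div computes, assuming (as the PyTorch documentation specifies) that the input is a log-probability distribution and the target is a probability distribution. The helpers `kl_div` and `log_softmax` below are hypothetical illustrations of the pointwise formula p * (log p - log q), not PyTorch's implementation, and the example scores and targets are made-up values. The sketch shows that 'mean' and 'batchmean' differ only by a constant factor equal to the number of classes.

```python
import math
from typing import List


def kl_div(log_q: List[List[float]], p: List[List[float]],
           reduction: str = "batchmean") -> float:
    """Hypothetical helper: sum of p * (log p - log_q), then reduced.

    log_q: model output as log-probabilities, one row per example.
    p: target probability distribution, one row per example.
    """
    total = 0.0
    for log_row, p_row in zip(log_q, p):
        for lq, pt in zip(log_row, p_row):
            # 0 * log 0 is treated as 0, so zero-probability targets add nothing
            if pt > 0:
                total += pt * (math.log(pt) - lq)
    batch_size = len(p)
    numel = sum(len(row) for row in p)
    if reduction == "sum":
        return total
    if reduction == "batchmean":
        return total / batch_size  # average KL divergence per example
    if reduction == "mean":
        return total / numel  # also divides by the number of classes
    raise ValueError(f"unknown reduction: {reduction}")


def log_softmax(scores: List[float]) -> List[float]:
    """Hypothetical helper: numerically stable log-softmax of one row."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]


# Two illustrative examples with five score bins each
targets = [[0.0, 0.0, 0.3, 0.7, 0.0],
           [0.0, 0.0, 0.0, 0.0, 1.0]]
outputs = [log_softmax([0.1, 0.2, 1.0, 2.0, 0.3]),
           log_softmax([0.0, 0.0, 0.0, 0.0, 3.0])]

print(kl_div(outputs, targets, "batchmean"))
print(kl_div(outputs, targets, "mean"))  # batchmean divided by 5 classes
```

Because the two reductions differ only by that constant, switching between them rescales the gradients, which interacts with the learning rate rather than changing which model minimizes the loss.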