msrvid_trainer.py source code

python

Project: MP-CNN-Variants  Author: tuzhucheng
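import math

import torch.nn.functional as F


# Method of the trainer class in msrvid_trainer.py (enclosing class not shown).
def train_epoch(self, epoch):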
        self.model.train()
        total_loss = 0

        # MSRVID has no validation set, so the last 20% of training batches are held out for validation
        batches = math.ceil(len(self.train_loader.dataset.examples) / self.batch_size)
        start_val_batch = math.floor(0.8 * batches)
        left_out_val_a, left_out_val_b = [], []
        left_out_val_ext_feats = []
        left_out_val_labels = []

        for batch_idx, batch in enumerate(self.train_loader):
            # held-out batches are collected for model selection and skipped during training
            if batch_idx >= start_val_batch:
                left_out_val_a.append(batch.sentence_1)
                left_out_val_b.append(batch.sentence_2)
                left_out_val_ext_feats.append(batch.ext_feats)
                left_out_val_labels.append(batch.label)
                continue
            self.optimizer.zero_grad()
            output = self.model(batch.sentence_1, batch.sentence_2, batch.ext_feats)
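            # F.kl_div expects log-probabilities as input and probabilities as target
            loss = F.kl_div(output, batch.label)
            # .item() replaces the deprecated loss.data[0], which fails on 0-dim tensors in PyTorch >= 0.5
            total_loss += loss.item()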
            loss.backward()
            self.optimizer.step()
            if batch_idx % self.log_interval == 0:
                self.logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, min(batch_idx * self.batch_size, len(batch.dataset.examples)),
                    len(batch.dataset.examples),
                    100. * batch_idx / len(self.train_loader), loss.item())
                )

        self.evaluate(self.train_evaluator, 'train')

        if self.use_tensorboard:
            self.writer.add_scalar('msrvid/train/kl_div_loss', total_loss, epoch)

        return left_out_val_a, left_out_val_b, left_out_val_ext_feats, left_out_val_labels
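
The hold-out arithmetic in train_epoch can be checked in isolation. The following is a minimal sketch, not part of the original repository: the helper names split_batches and partition, the train_fraction parameter, and the sizes 750 and 64 are all hypothetical and chosen only for illustration. It reproduces the math.ceil / math.floor computation and the batch_idx >= start_val_batch test on a toy list of batches.

```python
import math


def split_batches(num_examples, batch_size, train_fraction=0.8):
    """Mirror the index computation in train_epoch (hypothetical helper)."""
    # Number of batches, counting a final partial batch
    num_batches = math.ceil(num_examples / batch_size)
    # First batch index that is held out for validation
    start_val_batch = math.floor(train_fraction * num_batches)
    return num_batches, start_val_batch


def partition(batches, start_val_batch):
    """Split an iterable of batches the same way the training loop does."""
    train, held_out = [], []
    for batch_idx, batch in enumerate(batches):
        if batch_idx >= start_val_batch:
            held_out.append(batch)  # kept aside, never used for a gradient step
        else:
            train.append(batch)
    return train, held_out


# Hypothetical sizes: 750 examples, batch size 64
num_batches, start_val_batch = split_batches(750, 64)
toy_batches = list(range(num_batches))  # stand-ins for real batches
train, held_out = partition(toy_batches, start_val_batch)
print(num_batches, start_val_batch, len(train), len(held_out))
```

Because the split is made by batch position, the held-out examples are the same in every epoch only if the loader yields batches in a fixed order. If the loader reshuffles between epochs, a different subset would be held out each time.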