import torch.nn.functional as F  # assumed module-level import, needed for F.kl_div

def train_epoch(self, epoch):
    self.model.train()  # put the model in training mode (enables dropout, etc.)
    total_loss = 0
    for batch_idx, batch in enumerate(self.train_loader):
        self.optimizer.zero_grad()
        # the model is expected to return log-probabilities over the rating classes,
        # since F.kl_div interprets its first argument in log space
        output = self.model(batch.sentence_1, batch.sentence_2, batch.ext_feats)
        loss = F.kl_div(output, batch.label)
        total_loss += loss.item()  # loss.data[0] is the pre-0.4 idiom and fails on modern PyTorch
        loss.backward()
        self.optimizer.step()
        if batch_idx % self.log_interval == 0:
            self.logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, min(batch_idx * self.batch_size, len(batch.dataset.examples)),
                len(batch.dataset.examples),
                100. * batch_idx / len(self.train_loader), loss.item())
            )
    if self.use_tensorboard:
        # log the summed per-epoch loss, keyed by epoch number
        self.writer.add_scalar('sick/train/kl_div_loss', total_loss, epoch)
    return total_loss
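
For context on the loss: in this kind of SICK setup the relatedness score is typically recast as a sparse probability distribution over the five integer ratings, and the model's log-softmax output is matched against it with KL divergence (the scheme from Tai et al.'s Tree-LSTM paper). Below is a minimal sketch of that conversion; score_to_distribution is a hypothetical helper name, and the trainer above assumes batch.label already arrives in this form.

import torch

def score_to_distribution(y, num_classes=5):
    # Spread a relatedness score y in [1, num_classes] over the two
    # nearest integer ratings so the result sums to 1.
    # Hypothetical helper for illustration, not part of the trainer above.
    p = torch.zeros(num_classes)
    floor = int(y)
    if floor == y:
        p[floor - 1] = 1.0  # score sits exactly on an integer rating
    else:
        p[floor - 1] = floor + 1 - y  # mass on the lower rating
        p[floor] = y - floor          # remaining mass on the upper rating
    return p

# e.g. score_to_distribution(3.7) -> tensor([0.0, 0.0, 0.3, 0.7, 0.0])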