def train_epoch(self, epoch):
    self.model.train()
    total_loss = 0

    # since MSRVID doesn't have a validation set, we manually leave out
    # some of the training data for validation and model selection
    batches = math.ceil(len(self.train_loader.dataset.examples) / self.batch_size)
    start_val_batch = math.floor(0.8 * batches)
    left_out_val_a, left_out_val_b = [], []
    left_out_val_ext_feats = []
    left_out_val_labels = []

    for batch_idx, batch in enumerate(self.train_loader):
        if batch_idx >= start_val_batch:
            # collect the held-out batches instead of training on them
            left_out_val_a.append(batch.sentence_1)
            left_out_val_b.append(batch.sentence_2)
            left_out_val_ext_feats.append(batch.ext_feats)
            left_out_val_labels.append(batch.label)
            continue

        self.optimizer.zero_grad()
        output = self.model(batch.sentence_1, batch.sentence_2, batch.ext_feats)
        loss = F.kl_div(output, batch.label)
        total_loss += loss.data[0]
        loss.backward()
        self.optimizer.step()

        if batch_idx % self.log_interval == 0:
            self.logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, min(batch_idx * self.batch_size, len(batch.dataset.examples)),
                len(batch.dataset.examples),
                100. * batch_idx / len(self.train_loader), loss.data[0])
            )

    self.evaluate(self.train_evaluator, 'train')

    if self.use_tensorboard:
        self.writer.add_scalar('msrvid/train/kl_div_loss', total_loss, epoch)

    return left_out_val_a, left_out_val_b, left_out_val_ext_feats, left_out_val_labels
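
For reference, here is a minimal sketch of how the held-out batches returned by train_epoch could be scored after each epoch, using the average KL-divergence loss over them as a simple model-selection criterion. The helper name evaluate_left_out and the trainer variable are hypothetical, and the sketch assumes a recent PyTorch API (torch.no_grad, reduction='batchmean', loss.item()), whereas the snippet above uses the older loss.data[0] idiom.

import torch
import torch.nn.functional as F

def evaluate_left_out(model, val_a, val_b, val_ext_feats, val_labels):
    """Average KL-divergence loss over the held-out validation batches."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for a, b, ext_feats, label in zip(val_a, val_b, val_ext_feats, val_labels):
            output = model(a, b, ext_feats)  # model returns log-probabilities
            total_loss += F.kl_div(output, label, reduction='batchmean').item()
    model.train()
    return total_loss / max(len(val_labels), 1)

# example usage: keep the checkpoint with the lowest held-out loss across epochs
val_a, val_b, val_ext_feats, val_labels = trainer.train_epoch(epoch)
val_loss = evaluate_left_out(trainer.model, val_a, val_b, val_ext_feats, val_labels)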