def train(self, dataset):
    """Run one training epoch over ``dataset`` and return the mean per-sample loss.

    Samples are visited in a fresh random order (``torch.randperm``). Each
    sample's loss is back-propagated immediately, so gradients accumulate
    across samples. The optimizer steps once every ``self.args.batchsize``
    samples, and once more for any trailing partial batch.

    Args:
        dataset: indexable collection supporting ``len()`` with a
            ``num_classes`` attribute; each item unpacks to
            ``(ltree, lsent, rtree, rsent, label)``.

    Returns:
        The summed loss over all samples divided by ``len(dataset)``.

    Raises:
        ZeroDivisionError: if ``dataset`` is empty.
    """
    self.model.train()
    self.optimizer.zero_grad()
    total_loss = 0.0
    num_samples = len(dataset)
    batch_size = self.args.batchsize
    indices = torch.randperm(num_samples)
    for idx in tqdm(range(num_samples), desc='Training epoch ' + str(self.epoch + 1) + ''):
        ltree, lsent, rtree, rsent, label = dataset[indices[idx]]
        linput, rinput = Var(lsent), Var(rsent)
        target = Var(map_label_to_target(label, dataset.num_classes))
        if self.args.cuda:
            linput, rinput = linput.cuda(), rinput.cuda()
            target = target.cuda()
        output = self.model(ltree, linput, rtree, rinput)
        loss = self.criterion(output, target)
        # Indexing a 0-dim loss tensor (`loss.data[0]`) is an error in recent
        # PyTorch; prefer `.item()` and fall back only for old versions.
        total_loss += loss.item() if hasattr(loss, 'item') else loss.data[0]
        loss.backward()
        # Step after every full batch of `batch_size` samples and flush the
        # final partial batch. Stepping on `idx % batchsize == 0 and idx > 0`
        # would make the first batch one sample too large and silently discard
        # the gradients accumulated for the tail of the epoch.
        if (idx + 1) % batch_size == 0 or idx + 1 == num_samples:
            self.optimizer.step()
            self.optimizer.zero_grad()
    self.epoch += 1
    return total_loss / num_samples
# helper function for testing
评论列表
文章目录