def train(self, dataset):
    """Run one training epoch over ``dataset`` and return the mean per-example loss.

    Examples are visited in a fresh random order (``torch.randperm``). Each
    example's loss is back-propagated immediately, so gradients accumulate
    (summed) in the parameters. The optimizer steps once every
    ``self.args.batchsize`` examples. A trailing partial batch is also applied
    at the end of the epoch rather than being discarded. ``self.epoch`` is
    incremented by one.

    Args:
        dataset: indexable collection supporting ``len()``, whose items unpack
            into ``(ltree, lsent, rtree, rsent, label)`` and which exposes a
            ``num_classes`` attribute.

    Returns:
        float: the summed loss divided by ``len(dataset)``, or 0.0 when the
        dataset is empty.
    """
    self.model.train()
    self.optimizer.zero_grad()
    total_loss = 0.0
    num_examples = len(dataset)
    batch_size = self.args.batchsize
    # Plain Python ints, so ``dataset[...]`` is indexed with ints rather than
    # 0-dim tensors.
    indices = torch.randperm(num_examples).tolist()
    progress = tqdm(indices, desc='Training epoch ' + str(self.epoch + 1))
    for step, idx in enumerate(progress, start=1):
        ltree, lsent, rtree, rsent, label = dataset[idx]
        linput, rinput = Var(lsent), Var(rsent)
        target = Var(map_label_to_target(label, dataset.num_classes))
        if self.args.cuda:
            linput, rinput = linput.cuda(), rinput.cuda()
            target = target.cuda()
        output = self.model(ltree, linput, rtree, rinput)
        err = self.criterion(output, target)
        # .item() extracts the Python float. The old ``err.data[0]`` fails on
        # 0-dim loss tensors in newer PyTorch releases.
        total_loss += err.item()
        err.backward()
        if step % batch_size == 0:
            self.optimizer.step()
            self.optimizer.zero_grad()
    # Apply the gradients of a trailing partial batch. Otherwise the next
    # epoch's zero_grad() would silently throw that work away.
    if num_examples % batch_size:
        self.optimizer.step()
        self.optimizer.zero_grad()
    self.epoch += 1
    return total_loss / num_examples if num_examples else 0.0
# helper function for testing
评论列表
文章目录