def train(self, dataset):
    """Run one training epoch over ``dataset`` and return the mean per-sample loss.

    Samples are visited in a fresh random order each epoch (``torch.randperm``).
    Gradients from ``self.args.batchsize`` consecutive samples are accumulated
    (summed, not averaged -- each sample calls ``backward()`` on its own loss)
    before every ``self.optimizer.step()``. If the dataset size is not a
    multiple of the batch size, the trailing partial batch is applied too
    instead of being silently dropped.

    Args:
        dataset: indexable collection supporting ``len()``, whose items unpack
            to ``(ltree, lsent, rtree, rsent, label)`` and which exposes a
            ``num_classes`` attribute.

    Returns:
        float: sum of per-sample losses divided by ``len(dataset)``.
        An empty dataset raises ``ZeroDivisionError``, as before.

    Side effects:
        Puts ``self.model`` in training mode, updates its parameters via
        ``self.optimizer``, moves inputs to GPU when ``self.args.cuda`` is set,
        and increments ``self.epoch`` by one.
    """
    self.model.train()
    self.optimizer.zero_grad()
    batch_size = self.args.batchsize
    total_loss = 0.0
    # tolist() yields plain Python ints: same permutation as indexing the
    # tensor element-wise, but usable as an index on any PyTorch version.
    order = torch.randperm(len(dataset)).tolist()
    desc = 'Training epoch ' + str(self.epoch + 1) + ''
    # `range`/list iteration works on both Python 2 and 3 (`xrange` is Py2-only).
    for step, idx in enumerate(tqdm(order, desc=desc), 1):
        ltree, lsent, rtree, rsent, label = dataset[idx]
        linput, rinput = Var(lsent), Var(rsent)
        target = Var(map_label_to_target(label, dataset.num_classes))
        if self.args.cuda:
            linput, rinput = linput.cuda(), rinput.cuda()
            target = target.cuda()
        output = self.model(ltree, linput, rtree, rinput)
        err = self.criterion(output, target)
        # `.item()` exists from PyTorch 0.4 on; `.data[0]` fails on the 0-dim
        # losses of newer releases but is the only option on older ones.
        total_loss += (err.item() if hasattr(err, 'item') else err.data[0])
        err.backward()
        if step % batch_size == 0:
            self.optimizer.step()
            self.optimizer.zero_grad()
    # Apply the leftover partial batch rather than discarding its gradients.
    if len(order) % batch_size != 0:
        self.optimizer.step()
        self.optimizer.zero_grad()
    self.epoch += 1
    return total_loss / len(dataset)
# helper function for testing
评论列表
文章目录