def train(self, dataset):
    """Run one training epoch over *dataset* and return the mean loss.

    Iterates the dataset in a random order, accumulating gradients for
    ``self.args.batchsize`` samples before applying a manual SGD step on
    the embedding table (learning rate ``self.args.emblr``) plus an
    ``self.optimizer.step()`` for the tree model.

    Args:
        dataset: indexable collection yielding ``(tree, sent, label)``
            triples; ``label`` must be castable to ``int``.

    Returns:
        float: total accumulated loss divided by ``len(dataset)``.
    """
    self.model.train()
    self.embedding_model.train()
    self.embedding_model.zero_grad()
    self.optimizer.zero_grad()

    def _apply_batch():
        # Manual SGD update on the embedding parameters (they are kept
        # out of self.optimizer so they can use their own learning rate),
        # then step the main optimizer and clear all gradients.
        for p in self.embedding_model.parameters():
            p.data.sub_(p.grad.data * self.args.emblr)
        self.optimizer.step()
        self.embedding_model.zero_grad()
        self.optimizer.zero_grad()

    total_loss, k = 0.0, 0
    indices = torch.randperm(len(dataset))
    for idx in tqdm(range(len(dataset)), desc='Training epoch ' + str(self.epoch + 1) + ''):
        tree, sent, label = dataset[indices[idx]]
        sent_var = Var(sent)
        target = Var(torch.LongTensor([int(label)]))
        if self.args.cuda:
            sent_var = sent_var.cuda()
            target = target.cuda()
        # Add a batch dimension of 1: (seq_len, dim) -> (seq_len, 1, dim)
        # — presumably what the tree model expects; confirm against model.
        emb = F.torch.unsqueeze(self.embedding_model(sent_var), 1)
        output, err, _, _ = self.model.forward(tree, emb, training=True)
        # Scale so gradients accumulated over a batch average (not sum).
        err = err / self.args.batchsize
        total_loss += err.data[0]
        err.backward()
        k += 1
        if k == self.args.batchsize:
            _apply_batch()
            k = 0
    # BUGFIX: flush the trailing partial batch. Previously, when
    # len(dataset) % batchsize != 0, the last k samples' gradients were
    # accumulated but never applied (silently discarded on return).
    if k > 0:
        _apply_batch()

    self.epoch += 1
    return total_loss / len(dataset)
# helper function for testing