import torch

def train(epoch):
    for e_ in range(epoch):
        # decay the learning rate every 10 epochs
        if (e_ + 1) % 10 == 0:
            adjust_learning_rate(optimizer, e_)
        cnt = 0
        loss = torch.zeros(1)
        for i_q, i_w, i_e_p, i_a in zip(train_q, train_w, train_e_p, train_a):
            cnt += 1
            i_q = i_q.unsqueeze(0)  # add a batch dimension
            probs = model(i_q, i_w, i_e_p)
            curr_loss = loss_function(probs, i_a)
            # naive batch implementation: the loss (not the lr) is divided by
            # the batch size, so the gradient is averaged over the batch
            loss = loss + curr_loss / config.batch_size
            if cnt % config.batch_size == 0:
                print("Training loss", loss.item())
                loss.backward()
                optimizer.step()
                # reset the accumulated loss and the gradients for the next batch
                loss = torch.zeros(1)
                model.zero_grad()
            if cnt % config.valid_every == 0:
                print("Accuracy:", eval())  # eval() is the validation helper defined elsewhere
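
# adjust_learning_rate is called above but not defined in this snippet.
# A minimal sketch, assuming a simple multiplicative step decay; the 0.5
# factor and the unused `epoch` argument (kept only to match the call site)
# are illustrative assumptions, not taken from the original:
def adjust_learning_rate(optimizer, epoch, decay=0.5):
    # scale down the learning rate of every parameter group in place
    for param_group in optimizer.param_groups:
        param_group['lr'] *= decay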