# train.py
import os
import sys

import torch
import torch.nn.functional as F


def train(train_iter, dev_iter, model, args):
    if args.cuda:
        model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    steps = 0
    model.train()
    for epoch in range(1, args.epochs + 1):
        for batch in train_iter:
            feature, target = batch.text, batch.label
            # Make the batch dimension first and shift labels from 1-based to 0-based.
            feature.data.t_(), target.data.sub_(1)
            if args.cuda:
                feature, target = feature.cuda(), target.cuda()

            optimizer.zero_grad()
            logit = model(feature)
            loss = F.cross_entropy(logit, target)
            loss.backward()
            optimizer.step()

            steps += 1
            if steps % args.log_interval == 0:
                # Training accuracy on the current batch.
                corrects = (torch.max(logit, 1)[1].view(target.size()) == target).sum().item()
                accuracy = 100.0 * corrects / batch.batch_size
                sys.stdout.write(
                    '\rBatch[{}] - loss: {:.6f}  acc: {:.4f}%({}/{})'.format(
                        steps, loss.item(), accuracy, corrects, batch.batch_size))
            if steps % args.test_interval == 0:
                # Evaluate on the dev set (eval is defined elsewhere in this project).
                eval(dev_iter, model, args)
            if steps % args.save_interval == 0:
                # Periodically snapshot the whole model.
                if not os.path.isdir(args.save_dir):
                    os.makedirs(args.save_dir)
                save_prefix = os.path.join(args.save_dir, 'snapshot')
                save_path = '{}_steps{}.pt'.format(save_prefix, steps)
                torch.save(model, save_path)
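For orientation, here is a minimal sketch of how train() might be called. It assumes a text-classification model and torchtext-style iterators supplied by the rest of this project; the hyperparameter values and the commented-out helper names (CNN_Text, make_iterators) are placeholders, not part of the original code.

import torch
from argparse import Namespace

args = Namespace(
    lr=0.001,            # Adam learning rate
    epochs=10,           # passes over the training set
    log_interval=1,      # steps between console log lines
    test_interval=100,   # steps between dev-set evaluations
    save_interval=500,   # steps between model snapshots
    save_dir='snapshot',
    cuda=torch.cuda.is_available(),
)

# model = CNN_Text(args)                       # hypothetical model from this project
# train_iter, dev_iter = make_iterators(args)  # hypothetical torchtext iterators
# train(train_iter, dev_iter, model, args)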