def _loss_value(loss):
    """Return the scalar value held by a one-element loss tensor.

    Newer PyTorch releases expose ``.item()`` for this and reject ``[0]``
    indexing on 0-dim tensors; older releases (the ``Variable`` era) only
    support ``loss.data[0]``. Both are handled so the value can be plotted
    and printed regardless of the installed version.
    """
    if hasattr(loss, "item"):
        return loss.item()
    return loss.data[0]


def train(args):
    """Train a model with SGD on ``args.dataset`` and save it after each epoch.

    Args:
        args: Namespace-like object. The attributes read here are
            ``dataset``, ``img_rows``, ``img_cols``, ``batch_size``,
            ``visdom``, ``arch``, ``l_rate``, ``n_epoch`` and
            ``feature_scale``.

    Side effects:
        * Moves the model and every batch to CUDA (``.cuda()``).
        * When ``args.visdom`` is truthy, appends the loss of every
          minibatch to a Visdom line plot.
        * Prints the current loss every 20 minibatches.
        * Writes ``"{arch}_{dataset}_{feature_scale}_{epoch}.pkl"`` (a path
          relative to the working directory) at the end of every epoch.
    """
    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    loader = data_loader(data_path, is_transform=True, img_size=(args.img_rows, args.img_cols))
    n_classes = loader.n_classes
    trainloader = data.DataLoader(loader, batch_size=args.batch_size, num_workers=4, shuffle=True)

    # Setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom()
        loss_window = vis.line(X=torch.zeros((1,)).cpu(),
                               Y=torch.zeros((1)).cpu(),
                               opts=dict(xlabel='minibatches',
                                         ylabel='Loss',
                                         title='Training Loss',
                                         legend=['Loss']))

    # Setup Model
    model = get_model(args.arch, n_classes)
    model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    model.cuda()

    optimizer = torch.optim.SGD(model.parameters(), lr=args.l_rate, momentum=0.99, weight_decay=5e-4)

    # Minibatches seen across all epochs. Used as the plot's X coordinate so
    # the curve keeps advancing instead of restarting at 0 every epoch.
    global_step = 0

    for epoch in range(args.n_epoch):
        for i, (images, labels) in enumerate(trainloader):
            images = Variable(images.cuda())
            labels = Variable(labels.cuda())

            optimizer.zero_grad()
            outputs = model(images)
            loss = cross_entropy2d(outputs, labels)
            loss.backward()
            optimizer.step()

            # Read the scalar once; it feeds both the plot and the log line.
            loss_value = _loss_value(loss)

            if args.visdom:
                vis.line(
                    X=torch.ones((1, 1)).cpu() * global_step,
                    Y=torch.Tensor([loss_value]).unsqueeze(0).cpu(),
                    win=loss_window,
                    update='append')

            if (i + 1) % 20 == 0:
                print("Epoch [%d/%d] Loss: %.4f" % (epoch + 1, args.n_epoch, loss_value))

            global_step += 1

        # NOTE: this pickles the whole DataParallel-wrapped module object,
        # not just a state_dict; one checkpoint file is written per epoch.
        torch.save(model, "{}_{}_{}_{}.pkl".format(args.arch, args.dataset, args.feature_scale, epoch))
# Residue from the source web page ("comment list", "article table of contents"); not part of the program.