def train(train_loader, net, criterion, optimizer, curr_epoch, train_args, val_loader, visualize):
    while True:
        train_main_loss = AverageMeter()
        train_aux_loss = AverageMeter()
        curr_iter = (curr_epoch - 1) * len(train_loader)
        for i, data in enumerate(train_loader):
            # Polynomial ("poly") LR decay; the first parameter group runs at twice the base LR.
            optimizer.param_groups[0]['lr'] = 2 * train_args['lr'] * (
                1 - float(curr_iter) / train_args['max_iter']) ** train_args['lr_decay']
            optimizer.param_groups[1]['lr'] = train_args['lr'] * (
                1 - float(curr_iter) / train_args['max_iter']) ** train_args['lr_decay']

            inputs, gts, _ = data
            # inputs: [slices, batch, channels, H, W]; gts: [slices, batch, H, W]
            assert len(inputs.size()) == 5 and len(gts.size()) == 4
            inputs.transpose_(0, 1)
            gts.transpose_(0, 1)
            assert inputs.size()[3:] == gts.size()[2:]
            slice_batch_pixel_size = inputs.size(1) * inputs.size(3) * inputs.size(4)

            # Each "slice" is a sub-batch processed separately so the full batch fits in GPU memory.
            for inputs_slice, gts_slice in zip(inputs, gts):
                inputs_slice = Variable(inputs_slice).cuda()
                gts_slice = Variable(gts_slice).cuda()

                optimizer.zero_grad()
                outputs, aux = net(inputs_slice)
                assert outputs.size()[2:] == gts_slice.size()[1:]
                assert outputs.size()[1] == cityscapes.num_classes

                # PSPNet-style objective: main loss plus 0.4-weighted auxiliary loss.
                main_loss = criterion(outputs, gts_slice)
                aux_loss = criterion(aux, gts_slice)
                loss = main_loss + 0.4 * aux_loss
                loss.backward()
                optimizer.step()

                train_main_loss.update(main_loss.data[0], slice_batch_pixel_size)
                train_aux_loss.update(aux_loss.data[0], slice_batch_pixel_size)

            # One iteration per loaded batch; log to the module-level TensorBoard writer.
            curr_iter += 1
            writer.add_scalar('train_main_loss', train_main_loss.avg, curr_iter)
            writer.add_scalar('train_aux_loss', train_aux_loss.avg, curr_iter)
            writer.add_scalar('lr', optimizer.param_groups[1]['lr'], curr_iter)

            if (i + 1) % train_args['print_freq'] == 0:
                print('[epoch %d], [iter %d / %d], [train main loss %.5f], [train aux loss %.5f], [lr %.10f]' % (
                    curr_epoch, i + 1, len(train_loader), train_main_loss.avg, train_aux_loss.avg,
                    optimizer.param_groups[1]['lr']))
            if curr_iter >= train_args['max_iter']:
                return
            if curr_iter % train_args['val_freq'] == 0:
                validate(val_loader, net, criterion, optimizer, curr_epoch, i + 1, train_args, visualize)

        curr_epoch += 1
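
The loop relies on several names defined elsewhere in the script: the `AverageMeter` loss trackers, the TensorBoard `writer`, the `cityscapes` dataset module, and `validate`. Below is a minimal sketch of the `AverageMeter` it expects and of the polynomial learning-rate decay it assigns inline; both are written here as assumptions rather than copied from the original script.

class AverageMeter(object):
    """Weighted running average, as used for the main/aux loss meters above."""
    def __init__(self):
        self.val, self.avg, self.sum, self.count = 0.0, 0.0, 0.0, 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def poly_lr(base_lr, curr_iter, max_iter, lr_decay):
    """'Poly' schedule: the expression the loop assigns to both parameter groups
    (the first group runs at 2 * base_lr)."""
    return base_lr * (1 - float(curr_iter) / max_iter) ** lr_decay

Note that `curr_iter` is recomputed as `(curr_epoch - 1) * len(train_loader)` at the top of each outer pass and incremented once per batch, so the decay progresses continuously across epochs until `max_iter` is reached and the function returns.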