def main():
    """Train a ResNet-152 multi-label classifier with SGD.

    Builds augmented train/val loaders, configures SGD with weight decay
    applied only to non-bias parameters, and alternates that weight decay
    between `weight_decay` and 0 on every other epoch before delegating
    each epoch to `train()`.

    Relies on module-level names being in scope: `models`, `transforms`,
    `RandomVerticalFlip`, `MultipleClassImageFolder`, `DataLoader`, `nn`,
    `optim`, `os`, `num_classes`, `split_train_dir`, `split_val_dir`,
    `ckpt_path`, and `train`.
    """
    training_batch_size = 32
    validation_batch_size = 32
    epoch_num = 100
    iter_freq_print_training_log = 100
    iter_freq_validate = 500
    lr = 1e-2
    weight_decay = 1e-4

    # Fix: move the net to GPU. The criterion below is .cuda()'d, but the
    # original only called .cuda() in the commented-out snapshot path, which
    # would raise a device mismatch at the first forward/loss computation.
    net = models.get_res152(num_classes=num_classes).cuda()
    # net = get_res152(num_classes=num_classes, snapshot_path=os.path.join(ckpt_path, 'xxx.pth')).cuda()
    net.train()

    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        RandomVerticalFlip(),
        transforms.ToTensor(),
        # Per-channel mean/std — presumably precomputed on this dataset; TODO confirm.
        transforms.Normalize([0.311, 0.340, 0.299], [0.167, 0.144, 0.138])
    ])
    train_set = MultipleClassImageFolder(split_train_dir, transform)
    train_loader = DataLoader(train_set, batch_size=training_batch_size, shuffle=True, num_workers=16)
    val_set = MultipleClassImageFolder(split_val_dir, transform)
    val_loader = DataLoader(val_set, batch_size=validation_batch_size, shuffle=True, num_workers=16)

    criterion = nn.MultiLabelSoftMarginLoss().cuda()
    # Two parameter groups: biases get no weight decay, everything else does.
    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name.endswith('bias')]},
        {'params': [param for name, param in net.named_parameters() if not name.endswith('bias')],
         'weight_decay': weight_decay}
    ], lr=lr, momentum=0.9, nesterov=True)

    if not os.path.exists(ckpt_path):
        os.mkdir(ckpt_path)

    info = [1e9, 0, 0]  # [best val loss, epoch, iter]
    for epoch in range(epoch_num):
        # Deliberate schedule: disable weight decay on odd epochs, restore it on even ones.
        if epoch % 2 == 1:
            optimizer.param_groups[1]['weight_decay'] = 0
            print('weight_decay is set to 0')
        else:
            optimizer.param_groups[1]['weight_decay'] = weight_decay
            print('weight_decay is set to %.4f' % weight_decay)
        train(train_loader, net, criterion, optimizer, epoch, iter_freq_print_training_log, iter_freq_validate,
              val_loader, info)