train_planet.py source code

python

Project: kaggle-planet  Author: ZijunDeng
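import os

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

# Defined elsewhere in the kaggle-planet project (not shown in this snippet):
# models, MultipleClassImageFolder, RandomVerticalFlip, train,
# num_classes, ckpt_path, split_train_dir, split_val_dir


def main():
    training_batch_size = 32
    validation_batch_size = 32
    epoch_num = 100
    iter_freq_print_training_log = 100  # print the training log every 100 iterations
    iter_freq_validate = 500  # validate every 500 iterations
    lr = 1e-2
    weight_decay = 1e-4

    # ResNet-152 with a num_classes-way output; moved to the GPU like the criterion below
    net = models.get_res152(num_classes=num_classes).cuda()
    # To resume from a saved snapshot instead:
    # net = models.get_res152(num_classes=num_classes, snapshot_path=os.path.join(ckpt_path, 'xxx.pth')).cuda()
    net.train()

    # Random horizontal/vertical flips, then per-channel mean/std normalization.
    # Note: the same transform (including the random flips) is used for validation.
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        RandomVerticalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.311, 0.340, 0.299], [0.167, 0.144, 0.138])
    ])

    train_set = MultipleClassImageFolder(split_train_dir, transform)
    train_loader = DataLoader(train_set, batch_size=training_batch_size, shuffle=True, num_workers=16)
    val_set = MultipleClassImageFolder(split_val_dir, transform)
    val_loader = DataLoader(val_set, batch_size=validation_batch_size, shuffle=True, num_workers=16)

    # Multi-label loss: an independent sigmoid + binary cross-entropy per class
    criterion = nn.MultiLabelSoftMarginLoss().cuda()
    # Group 0: biases, no weight decay. Group 1: all other parameters, with weight decay.
    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name.endswith('bias')]},
        {'params': [param for name, param in net.named_parameters() if not name.endswith('bias')],
         'weight_decay': weight_decay}
    ], lr=lr, momentum=0.9, nesterov=True)

    if not os.path.exists(ckpt_path):
        os.mkdir(ckpt_path)

    info = [1e9, 0, 0]  # [best val loss, epoch, iter]

    for epoch in range(epoch_num):
        # Alternate weight decay on group 1: off on odd epochs, on on even epochs
        if epoch % 2 == 1:
            optimizer.param_groups[1]['weight_decay'] = 0
            print('weight_decay is set to 0')
        else:
            optimizer.param_groups[1]['weight_decay'] = weight_decay
            print('weight_decay is set to %.4f' % weight_decay)
        train(train_loader, net, criterion, optimizer, epoch, iter_freq_print_training_log, iter_freq_validate,
              val_loader, info)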
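The script trains with `nn.MultiLabelSoftMarginLoss`, which treats each tag as an independent binary problem: for each image it averages `-(y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x)))` over the C classes, then averages over the batch. The check below is a small sketch of that. It assumes random logits and 17 tags (the Kaggle Planet tag count; `num_classes` itself is not shown in the snippet), and the helper `manual_multilabel_soft_margin` is hypothetical. It compares the built-in loss against a hand-written version and against `nn.BCEWithLogitsLoss` with its default mean reduction:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def manual_multilabel_soft_margin(logits, targets):
    # Hypothetical helper: per-class binary cross-entropy on sigmoid(logits),
    # averaged over the classes (dim=1), then over the batch.
    log_p = F.logsigmoid(logits)       # log(sigmoid(x))
    log_not_p = F.logsigmoid(-logits)  # log(1 - sigmoid(x))
    per_class = -(targets * log_p + (1 - targets) * log_not_p)
    return per_class.mean(dim=1).mean()


torch.manual_seed(0)
logits = torch.randn(4, 17)                      # 4 images, 17 tags (assumed)
targets = torch.randint(0, 2, (4, 17)).float()   # random multi-hot labels

builtin = nn.MultiLabelSoftMarginLoss()(logits, targets)
manual = manual_multilabel_soft_margin(logits, targets)
bce = nn.BCEWithLogitsLoss()(logits, targets)
print(builtin.item(), manual.item(), bce.item())  # three (near-)identical values
```

Because every image has the same number of classes, averaging over classes and then over the batch equals averaging over all N*C entries. That is why the two built-in losses agree here.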
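The optimizer setup puts biases in a group with no weight decay (SGD's default `weight_decay` is 0) and all other parameters in a group with `weight_decay`. The epoch loop then switches that second group's decay off on odd epochs and back on for even epochs. The sketch below shows the same mechanism on a hypothetical two-layer toy model, with a hypothetical helper `set_weight_decay_for_epoch` standing in for the loop body:

```python
import torch.nn as nn
import torch.optim as optim

# Hypothetical toy model; parameter names are '0.weight', '0.bias', '2.weight', '2.bias'
net = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
weight_decay = 1e-4

# Group 0: biases, no weight decay. Group 1: everything else, with weight decay.
bias_params = [p for n, p in net.named_parameters() if n.endswith('bias')]
other_params = [p for n, p in net.named_parameters() if not n.endswith('bias')]
optimizer = optim.SGD([
    {'params': bias_params},
    {'params': other_params, 'weight_decay': weight_decay},
], lr=1e-2, momentum=0.9, nesterov=True)


def set_weight_decay_for_epoch(optimizer, epoch, weight_decay):
    # Mirrors the loop in main(): odd epochs turn decay off on group 1,
    # even epochs restore it. Group 0 (biases) is never decayed.
    wd = 0.0 if epoch % 2 == 1 else weight_decay
    optimizer.param_groups[1]['weight_decay'] = wd
    return wd


schedule = [set_weight_decay_for_epoch(optimizer, e, weight_decay) for e in range(4)]
print(len(bias_params), len(other_params))  # 2 2
print(schedule)  # [0.0001, 0.0, 0.0001, 0.0]
```

Changing `param_groups[i]['weight_decay']` between steps takes effect on the next `optimizer.step()`. No new optimizer is needed, so the momentum buffers are preserved.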
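`RandomVerticalFlip` is a project-local class, and its definition is not part of this snippet. The following is a hypothetical sketch of what such a transform could look like. It assumes the transform receives a PIL image (it runs before `ToTensor()` in the pipeline) and flips with probability 0.5, like `transforms.RandomHorizontalFlip`:

```python
import random

from PIL import Image, ImageOps


class RandomVerticalFlip(object):
    """Hypothetical sketch: flip a PIL image top-to-bottom with probability p."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            return ImageOps.flip(img)  # vertical (top-to-bottom) flip
        return img


img = Image.new('RGB', (2, 2))       # 2x2 black image
img.putpixel((0, 0), (255, 0, 0))    # red pixel at top-left
flipped = RandomVerticalFlip(p=1.0)(img)
print(flipped.getpixel((0, 1)))      # (255, 0, 0): now at bottom-left
```

Recent torchvision releases also provide `transforms.RandomVerticalFlip`, which could take the place of a custom class like this.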