MNIST_with_centerloss.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:MNIST_center_loss_pytorch 作者: jxgu1016 项目源码 文件源码
def main():
    """Train ``Net`` on the MNIST training split with NLL loss plus center loss.

    Builds the MNIST training loader, the model, both loss modules and one
    SGD optimizer per parameter set (network weights and center-loss
    parameters), then calls ``train`` once per epoch for 50 epochs. The
    model and both losses are moved to the GPU when CUDA is available.

    The learning-rate scheduler is stepped *after* each epoch's training.
    Since PyTorch 1.1.0 the scheduler must be stepped after the optimizer;
    stepping it first skips the first value of the schedule.
    """
    use_cuda = torch.cuda.is_available()

    # Dataset: tensors normalized with mean 0.1307 and std 0.3081.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    trainset = datasets.MNIST('../../data', download=True, train=True,
                              transform=mnist_transform)
    train_loader = DataLoader(trainset, batch_size=128, shuffle=True,
                              num_workers=4)

    # Model
    model = Net()

    # NLLLoss: CrossEntropyLoss = log_softmax + NLLLoss, so this presumably
    # relies on Net emitting log-probabilities -- confirm against Net.forward.
    nllloss = nn.NLLLoss()

    # CenterLoss: the arguments (10, 2, loss_weight) are presumably the number
    # of classes, the feature dimension and the loss weight -- confirm against
    # the CenterLoss definition.
    loss_weight = 1.0
    centerloss = CenterLoss(10, 2, loss_weight)

    if use_cuda:
        nllloss = nllloss.cuda()
        centerloss = centerloss.cuda()
        model = model.cuda()
    criterion = [nllloss, centerloss]

    # Optimizer for the network weights; StepLR multiplies its learning rate
    # by 0.8 every 20 scheduler steps (one step per epoch here).
    optimizer4nn = optim.SGD(model.parameters(), lr=0.001, momentum=0.9,
                             weight_decay=0.0005)
    scheduler = lr_scheduler.StepLR(optimizer4nn, 20, gamma=0.8)

    # Separate optimizer for the parameters owned by the center loss; it is
    # not attached to the scheduler, so its learning rate stays at 0.5.
    optimizer4center = optim.SGD(centerloss.parameters(), lr=0.5)

    for epoch in range(50):
        # train() receives a 1-based epoch number.
        train(train_loader, model, criterion,
              [optimizer4nn, optimizer4center], epoch + 1, use_cuda)
        # Advance the schedule only after this epoch's training has run
        # (assumes train() performs the optimizer updates -- it is defined
        # outside this block).
        scheduler.step()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号