def main():
    """Train ``Net`` on MNIST with a joint NLL + center-loss objective.

    Builds the MNIST training loader, the model, the two criteria and two SGD
    optimizers: one for the network weights, one for the parameters owned by
    ``CenterLoss``. It then runs 50 epochs, delegating each epoch to ``train``.
    The model and both criteria are moved to the GPU when CUDA is available.
    """
    use_cuda = torch.cuda.is_available()

    # Dataset. (0.1307,) / (0.3081,) are the commonly used MNIST mean / std.
    trainset = datasets.MNIST(
        '../../data',
        download=True,
        train=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]),
    )
    train_loader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4)

    # Model
    model = Net()

    # NLLLoss expects log-probabilities (CrossEntropyLoss = log_softmax + NLLLoss),
    # so Net presumably ends in log_softmax -- TODO confirm against Net.
    nllloss = nn.NLLLoss()

    # CenterLoss arguments are presumably (num_classes=10, feat_dim=2, loss_weight).
    # Verify against the CenterLoss definition.
    loss_weight = 1.0
    centerloss = CenterLoss(10, 2, loss_weight)

    if use_cuda:
        nllloss = nllloss.cuda()
        centerloss = centerloss.cuda()
        model = model.cuda()

    # train() receives criteria and optimizers as lists; it presumably indexes
    # them positionally, so keep the [network, center] order consistent.
    criterion = [nllloss, centerloss]

    # Optimizer for the network weights; StepLR multiplies its lr by 0.8 every 20 epochs.
    optimizer4nn = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005)
    scheduler = lr_scheduler.StepLR(optimizer4nn, 20, gamma=0.8)

    # Separate optimizer, with its own higher lr, for the parameters held by CenterLoss.
    optimizer4center = optim.SGD(centerloss.parameters(), lr=0.5)

    for epoch in range(50):
        train(train_loader, model, criterion, [optimizer4nn, optimizer4center], epoch + 1, use_cuda)
        # Since PyTorch 1.1 the scheduler must step *after* the epoch's optimizer
        # steps. Stepping first skips the initial lr value and triggers a warning.
        scheduler.step()
MNIST_with_centerloss.py 文件源码
python
阅读 22
收藏 0
点赞 0
评论 0
评论列表
文章目录