import torch.optim as optim


def get_optimizer(args, params):
    # Only 'mnist' and the two CIFAR datasets are handled here; any other
    # value of args.dataset leaves `optimizer` unbound and the final return fails.
    if args.dataset == 'mnist':
        if args.model == 'optnet-eq':
            # The first parameter is the equality-constraint matrix A of the
            # OptNet layer; it gets a larger learning rate than the rest.
            params = list(params)
            A_param = params.pop(0)
            assert A_param.size() == (args.neq, args.nHidden)
            optimizer = optim.Adam([
                {'params': params, 'lr': 1e-3},
                {'params': [A_param], 'lr': 1e-1},
            ])
        else:
            optimizer = optim.Adam(params)
    elif args.dataset in ('cifar-10', 'cifar-100'):
        if args.opt == 'sgd':
            optimizer = optim.SGD(params, lr=1e-1, momentum=0.9,
                                  weight_decay=args.weightDecay)
        elif args.opt == 'adam':
            optimizer = optim.Adam(params, weight_decay=args.weightDecay)
        else:
            assert False, 'unknown optimizer: ' + args.opt
    return optimizer
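
A minimal usage sketch (not part of the original script; the bare argparse Namespace and the stand-in nn.Linear model are assumptions for illustration) showing how get_optimizer plugs into a training step:

import argparse

import torch.nn as nn

# Hypothetical args: CIFAR-10 with SGD and a small weight decay.
args = argparse.Namespace(dataset='cifar-10', opt='sgd', weightDecay=1e-4)
net = nn.Linear(3 * 32 * 32, 10)  # stand-in model for illustration
optimizer = get_optimizer(args, net.parameters())

# Typical step: clear stale gradients, backprop the loss, update the weights.
optimizer.zero_grad()
# loss.backward()  # loss computation omitted in this sketch
optimizer.step()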