def __init__(self, model, criterion, opt, optimState):
self.model = model
self.criterion = criterion
self.optimState = optimState
if self.optimState == None:
self.optimState = { 'learningRate' : opt.LR,
'learningRateDecay' : opt.LRDParam,
'momentum' : opt.momentum,
'nesterov' : False,
'dampening' : opt.dampening,
'weightDecay' : opt.weightDecay
}
self.opt = opt
if opt.optimizer == 'SGD':
self.optimizer = optim.SGD(model.parameters(), lr=opt.LR, momentum=opt.momentum, dampening=opt.dampening, weight_decay=opt.weightDecay)
elif opt.optimizer == 'Adam':
self.optimizer = optim.Adam(model.parameters(), lr=opt.LR, betas=(opt.momentum, 0.999), eps=1e-8, weight_decay=opt.weightDecay)
self.logger = { 'train' : open(os.path.join(opt.resume, 'train.log'), 'a+'),
'val' : open(os.path.join(opt.resume, 'val.log'), 'a+')
}
# NOTE(review): removed web-scrape page artifacts that broke the module
# (template text "评论列表" / "文章目录" — "comment list" / "table of contents").