ENet-train.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:PytorchDL 作者: FredHuangBia 项目源码 文件源码
def __init__(self, model, criterion, opt, optimState):
        self.model = model
        self.criterion = criterion
        self.optimState = optimState
        if self.optimState == None:
            self.optimState = { 'learningRate' : opt.LR,
                                'learningRateDecay' : opt.LRDParam,
                                'momentum' : opt.momentum,
                                'nesterov' : False,
                                'dampening'  : opt.dampening,
                                'weightDecay' : opt.weightDecay
                            }
        self.opt = opt
        if opt.optimizer == 'SGD':
            self.optimizer = optim.SGD(model.parameters(), lr=opt.LR, momentum=opt.momentum, dampening=opt.dampening, weight_decay=opt.weightDecay)
        elif opt.optimizer == 'Adam':
            self.optimizer = optim.Adam(model.parameters(), lr=opt.LR, betas=(opt.momentum, 0.999), eps=1e-8, weight_decay=opt.weightDecay)

        self.logger = { 'train' : open(os.path.join(opt.resume, 'train.log'), 'a+'), 
                        'val' : open(os.path.join(opt.resume, 'val.log'), 'a+')
                    }
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号