def set_momentum(opt, momentum):
    """Set the momentum-like hyperparameter of an optimizer in place.

    The optimizer classes come from the ``optimizers`` module in scope
    (presumably ``chainer.optimizers`` -- confirm against the file's imports).

    Args:
        opt: Optimizer instance to modify.
        momentum: New value for the optimizer's momentum coefficient.

    Raises:
        NotImplementedError: If ``opt`` is not a NesterovAG, MomentumSGD,
            SGD or Adam instance.
    """
    # NesterovAG and MomentumSGD both store the coefficient as ``momentum``.
    if isinstance(opt, (optimizers.NesterovAG, optimizers.MomentumSGD)):
        opt.momentum = momentum
        return
    # Plain SGD is deliberately accepted as a no-op: nothing is assigned.
    if isinstance(opt, optimizers.SGD):
        return
    # For Adam the value is written to ``beta1`` (the first-moment decay rate),
    # which acts as its momentum term.
    if isinstance(opt, optimizers.Adam):
        opt.beta1 = momentum
        return
    # Name the offending type so callers can tell which optimizer is unsupported.
    raise NotImplementedError(
        'set_momentum does not support optimizer type {}'.format(
            type(opt).__name__))
评论列表
文章目录