def get_optimizer(name, lr, momentum=0.9):
if name.lower() == "adam":
return chainer.optimizers.Adam(alpha=lr, beta1=momentum)
if name.lower() == "eve":
return Eve(alpha=lr, beta1=momentum)
if name.lower() == "adagrad":
return chainer.optimizers.AdaGrad(lr=lr)
if name.lower() == "adadelta":
return chainer.optimizers.AdaDelta(rho=momentum)
if name.lower() == "nesterov" or name.lower() == "nesterovag":
return chainer.optimizers.NesterovAG(lr=lr, momentum=momentum)
if name.lower() == "rmsprop":
return chainer.optimizers.RMSprop(lr=lr, alpha=momentum)
if name.lower() == "momentumsgd":
return chainer.optimizers.MomentumSGD(lr=lr, mommentum=mommentum)
if name.lower() == "sgd":
return chainer.optimizers.SGD(lr=lr)
raise Exception()
评论列表
文章目录