def get_optimizer(name, lr, momentum):
    """Build an optimizer instance from a short, case-insensitive name.

    Args:
        name: Optimizer identifier, compared case-insensitively; one of
            ``"sgd"``, ``"msgd"``, ``"nesterov"`` or ``"adam"``.
        lr: Learning rate. Passed as ``lr`` to the SGD variants and as
            ``alpha`` to Adam.
        momentum: Momentum coefficient. Ignored by plain SGD, passed as
            ``momentum`` to the momentum variants and as ``beta1`` to Adam.

    Returns:
        A newly constructed optimizer object from the ``optimizers`` module
        (presumably ``chainer.optimizers`` -- confirm against the file's
        imports).

    Raises:
        NotImplementedError: If ``name`` is not one of the supported names.
    """
    # Factories are lambdas so that only the selected optimizer is built,
    # exactly as with the original if-chain.
    factories = {
        "sgd": lambda: optimizers.SGD(lr=lr),
        "msgd": lambda: optimizers.MomentumSGD(lr=lr, momentum=momentum),
        "nesterov": lambda: optimizers.NesterovAG(lr=lr, momentum=momentum),
        # Adam has no plain momentum term; its beta1 (decay rate of the
        # first-moment estimate) is the closest analogue, so reuse it.
        "adam": lambda: optimizers.Adam(alpha=lr, beta1=momentum),
    }
    factory = factories.get(name.lower())
    if factory is None:
        # Keep NotImplementedError for backward compatibility with callers,
        # but say what was wrong and what would have been accepted.
        raise NotImplementedError(
            "unknown optimizer {!r}; expected one of: {}".format(
                name, ", ".join(sorted(factories))))
    return factory()
评论列表
文章目录