def get_optimizer(name, lr, momentum):
    """Build the optimizer selected by ``name``.

    Args:
        name: Optimizer key; one of ``"nesterov"``, ``"adam"`` or ``"sgd"``
            (compared exactly, so the match is case-sensitive).
        lr: Learning rate. Forwarded as ``lr`` to ``NesterovAG`` and ``SGD``,
            and as ``alpha`` to ``Adam``.
        momentum: Forwarded as ``momentum`` to ``NesterovAG`` and as
            ``beta1`` to ``Adam``; unused for ``"sgd"``.

    Returns:
        The optimizer instance created from the ``optimizers`` module.

    Raises:
        NotImplementedError: If ``name`` is not one of the supported keys.
    """
    if name == "nesterov":
        return optimizers.NesterovAG(lr=lr, momentum=momentum)
    if name == "adam":
        return optimizers.Adam(alpha=lr, beta1=momentum)
    if name == "sgd":
        # SGD is constructed from the learning rate only; `momentum` is unused.
        return optimizers.SGD(lr=lr)
    # Must be the built-in NotImplementedError: the misspelled
    # `NotImplementationError` is undefined and would surface as a NameError.
    raise NotImplementedError("unsupported optimizer: {!r}".format(name))
评论列表
文章目录