def get_updates(nnet,
                train_obj,
                trainable_params,
                learning_rate=None,
                momentum=0.9):
    """Build parameter-update rules for the solver configured on ``nnet``.

    The solver name is taken from ``nnet.solver``. If ``nnet`` has no
    ``solver`` attribute, or its value is not one of the implemented
    solvers ("nesterov", "adagrad", "adadelta", "adam"), Nesterov momentum
    is used silently. The solver actually chosen is written back to
    ``nnet.sgd_solver``.

    Args:
        nnet: network object; ``solver`` is read from it (if present) and
            ``sgd_solver`` is assigned on it.
        train_obj: training objective, passed as the first argument to the
            selected ``l_updates`` function.
        trainable_params: parameters to be updated, passed as the second
            argument to the selected ``l_updates`` function.
        learning_rate: step size handed to every solver. ``None`` (the
            default) means ``Cfg.learning_rate``, read at call time.
        momentum: momentum coefficient; used only by the Nesterov solver.

    Returns:
        Whatever the selected ``l_updates`` function returns (presumably a
        mapping of parameter -> update expression — verify against the
        ``l_updates`` module in use).
    """
    implemented_solvers = ("nesterov", "adagrad", "adadelta", "adam")

    # Unknown or missing solver names fall back to Nesterov momentum rather
    # than raising; the resolved name is recorded on the network.
    if hasattr(nnet, "solver") and nnet.solver in implemented_solvers:
        nnet.sgd_solver = nnet.solver
    else:
        nnet.sgd_solver = "nesterov"

    # Resolved after sgd_solver is set, matching the original ordering.
    if learning_rate is None:
        learning_rate = Cfg.learning_rate

    if nnet.sgd_solver == "nesterov":
        return l_updates.nesterov_momentum(train_obj,
                                           trainable_params,
                                           learning_rate=learning_rate,
                                           momentum=momentum)
    if nnet.sgd_solver == "adagrad":
        solver_fn = l_updates.adagrad
    elif nnet.sgd_solver == "adadelta":
        solver_fn = l_updates.adadelta
    else:  # "adam" -- the only remaining implemented solver
        solver_fn = l_updates.adam
    return solver_fn(train_obj,
                     trainable_params,
                     learning_rate=learning_rate)
评论列表
文章目录